# ProtoNN in Tensorflow

This is a simple notebook that illustrates the usage of Tensorflow implementation of ProtoNN. We are using the USPS dataset. Please refer to `fetch_usps.py` for more details on downloading the dataset.

In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import sys
import os
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

#sys.path.insert(0, '../../')
# from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
# from edgeml.graph.protoNN import ProtoNN
# import edgeml.utils as utils
# import helpermethods as helper
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV


sys.path.append(r"D:\programming\practice\research\protoNN\EdgeML\examples\tf\ProtoNN")
import helpermethods as helper

Instructions for updating:
non-resource variables are not supported in the long term


# Helper Methods

In [2]:
#helper methods
sys.path.insert(0, '../')
import argparse


def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4):
    '''
    expected: Expected size according to the parameters set. The number of
        zeros could actually be more than that is required to satisfy the
        sparsity constraint.
    '''
    nnzList, sizeList, isSparseList = [], [], []
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        assert A.ndim == 2
        assert s >= 0
        assert s <= 1
        nnz, size, sparse = countnnZ(A, s, bytesPerVar=bytesPerVar)
        nnzList.append(nnz)
        sizeList.append(size)
        hasSparse = (hasSparse or sparse)

    totalnnZ = np.sum(nnzList)
    totalSize = np.sum(sizeList)
    if expected:
        return totalnnZ, totalSize, hasSparse
    numNonZero = 0
    totalSize = 0
    hasSparse = False
    for i in range(len(matrixList)):
        A, s = matrixList[i], sparcityList[i]
        numNonZero_ = np.count_nonzero(A)
        numNonZero += numNonZero_
        hasSparse = (hasSparse or (s < 0.5))
        if s <= 0.5:
            totalSize += numNonZero_ * 2 * bytesPerVar
        else:
            totalSize += A.size * bytesPerVar
    return numNonZero, totalSize, hasSparse


def getGamma(gammaInit, projectionDim, dataDim, numPrototypes, x_train):
    if gammaInit is None:
        print("Using median heuristic to estimate gamma.")
        gamma, W, B = medianHeuristic(x_train, projectionDim,
                                            numPrototypes)
        print("Gamma estimate is: %f" % gamma)
        return W, B, gamma
    return None, None, gammaInit


def preprocessData(dataDir,w):
    '''
    Loads data from the dataDir and does some initial preprocessing
    steps. Data is assumed to be contained in two files,
    train.npy and test.npy. Each containing a 2D numpy array of dimension
    [numberOfExamples, numberOfFeatures + 1]. The first column of each
    matrix is assumed to contain label information.

    For an N-Class problem, we assume the labels are integers from 0 through
    N-1.
    '''
    # Uncomment for usual training data
    # train = np.load(dataDir + '/train_'+str(w)+'.npy')
    # test = np.load(dataDir + '/test_'+str(w)+'.npy')
    # Uncomment for time domain training data
    train = np.load(dataDir + '/ttrain_'+str(w)+'.npy')
    test = np.load(dataDir + '/ttest_'+str(w)+'.npy')
    # Uncomment for 1 sensordrop training data
    # train = np.load(dataDir + '/train_'+str(w)+'.npy')
    # test = np.load(dataDir + '/test_'+str(w)+'.npy')

    dataDimension = int(train.shape[1]) - 1
    x_train = train[:, 1:dataDimension + 1]
    y_train_ = train[:, 0]
    x_test = test[:, 1:dataDimension + 1]
    y_test_ = test[:, 0]

    numClasses = max(y_train_) - min(y_train_) + 1
    numClasses = max(numClasses, max(y_test_) - min(y_test_) + 1)
    numClasses = int(numClasses)

    # mean-var
    mean = np.mean(x_train, 0)
    std = np.std(x_train, 0)
    std[std[:] < 0.000001] = 1
    x_train = (x_train - mean) / std
    x_test = (x_test - mean) / std

    # one hot y-train
    lab = y_train_.astype('uint8')
    lab = np.array(lab) - min(lab)
    lab_ = np.zeros((x_train.shape[0], numClasses))
    lab_[np.arange(x_train.shape[0]), lab] = 1
    y_train = lab_

    # one hot y-test
    lab = y_test_.astype('uint8')
    lab = np.array(lab) - min(lab)
    lab_ = np.zeros((x_test.shape[0], numClasses))
    lab_[np.arange(x_test.shape[0]), lab] = 1
    y_test = lab_

    return dataDimension, numClasses, x_train, y_train, x_test, y_test



def getProtoNNArgs():
    def checkIntPos(value):
        ivalue = int(value)
        if ivalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive int value" % value)
        return ivalue

    def checkIntNneg(value):
        ivalue = int(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg int value" % value)
        return ivalue

    def checkFloatNneg(value):
        fvalue = float(value)
        if fvalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg float value" % value)
        return fvalue

    def checkFloatPos(value):
        fvalue = float(value)
        if fvalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive float value" % value)
        return fvalue

    '''
    Parse protoNN commandline arguments
    '''
    parser = argparse.ArgumentParser(
        description='Hyperparameters for ProtoNN Algorithm')

    msg = 'Data directory containing train and test data. The '
    msg += 'data is assumed to be saved as 2-D numpy matrices with '
    msg += 'names `train.npy` and `test.npy`, of dimensions\n'
    msg += '\t[numberOfInstances, numberOfFeatures + 1].\n'
    msg += 'The first column of each file is assumed to contain label information.'
    msg += ' For a N-class problem, labels are assumed to be integers from 0 to'
    msg += ' N-1 (inclusive).'
    parser.add_argument('-d', '--data-dir', required=True, help=msg)
    parser.add_argument('-l', '--projection-dim', type=checkIntPos, default=10,
                        help='Projection Dimension.')
    parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20,
                        help='Number of prototypes.')
    parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None,
                        help='Gamma for Gaussian kernel. If not provided, ' +
                        'median heuristic will be used to estimate gamma.')

    parser.add_argument('-e', '--epochs', type=checkIntPos, default=100,
                        help='Total training epochs.')
    parser.add_argument('-b', '--batch-size', type=checkIntPos, default=32,
                        help='Batch size for each pass.')
    parser.add_argument('-r', '--learning-rate', type=checkFloatPos,
                        default=0.001,
                        help='Initial Learning rate for ADAM Optimizer.')

    parser.add_argument('-rW', type=float, default=0.000,
                        help='Coefficient for l2 regularizer for predictor' +
                        ' parameter W ' + '(default = 0.0).')
    parser.add_argument('-rB', type=float, default=0.00,
                        help='Coefficient for l2 regularizer for predictor' +
                        ' parameter B ' + '(default = 0.0).')
    parser.add_argument('-rZ', type=float, default=0.00,
                        help='Coefficient for l2 regularizer for predictor' +
                        'parameter Z ' +
                        '(default = 0.0).')

    parser.add_argument('-sW', type=float, default=1.000,
                        help='Sparsity constraint for predictor parameter W ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-sB', type=float, default=1.00,
                        help='Sparsity constraint for predictor parameter B ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-sZ', type=float, default=1.00,
                        help='Sparsity constraint for predictor parameter Z ' +
                        '(default = 1.0, i.e. dense matrix).')
    parser.add_argument('-pS', '--print-step', type=int, default=200,
                        help='The number of update steps between print ' +
                        'calls to console.')
    parser.add_argument('-vS', '--val-step', type=int, default=3,
                        help='The number of epochs between validation' +
                        'performance evaluation')
    return parser.parse_args()

# Utils 

In [3]:
#utils
import scipy.cluster
import scipy.spatial
import os


def medianHeuristic(data, projectionDimension, numPrototypes, W_init=None):
    '''
    This method can be used to estimate gamma for ProtoNN. An approximation to
    median heuristic is used here.
    1. First the data is collapsed into the projectionDimension by W_init. If
    W_init is not provided, it is initialized from a random normal(0, 1). Hence
    data normalization is essential.
    2. Prototype are computed by running a  k-means clustering on the projected
    data.
    3. The median distance is then estimated by calculating median distance
    between prototypes and projected data points.

    data needs to be [-1, numFeats]
    If using this method to initialize gamma, please use the W and B as well.

    TODO: Return estimate of Z (prototype labels) based on cluster centroids
    andand labels

    TODO: Clustering fails due to singularity error if projecting upwards

    W [dxd_cap]
    B [d_cap, m]
    returns gamma, W, B
    '''
    assert data.ndim == 2
    X = data
    featDim = data.shape[1]
    if projectionDimension > featDim:
        print("Warning: Projection dimension > feature dimension. Gamma")
        print("\t estimation due to median heuristic could fail.")
        print("\tTo retain the projection dataDimension, provide")
        print("\ta value for gamma.")

    if W_init is None:
        W_init = np.random.normal(size=[featDim, projectionDimension])
    W = W_init
    XW = np.matmul(X, W)
    assert XW.shape[1] == projectionDimension
    assert XW.shape[0] == len(X)
    # Requires [N x d_cap] data matrix of N observations of d_cap-dimension and
    # the number of centroids m. Returns, [n x d_cap] centroids and
    # elementwise center information.
    B, centers = scipy.cluster.vq.kmeans2(XW, numPrototypes)
    # Requires two matrices. Number of observations x dimension of observation
    # space. Distances[i,j] is the distance between XW[i] and B[j]
    distances = scipy.spatial.distance.cdist(XW, B, metric='euclidean')
    distances = np.reshape(distances, [-1])
    gamma = np.median(distances)
    gamma = 1 / (2.5 * gamma)
    return gamma.astype('float32'), W.astype('float32'), B.T.astype('float32')


def multiClassHingeLoss(logits, label, batch_th):
    '''
    MultiClassHingeLoss to match C++ Version - No TF internal version
    '''
    flatLogits = tf.reshape(logits, [-1, ])
    label_ = tf.argmax(label, 1)

    correctId = tf.range(0, batch_th) * label.shape[1] + label_
    correctLogit = tf.gather(flatLogits, correctId)

    maxLabel = tf.argmax(logits, 1)
    top2, _ = tf.nn.top_k(logits, k=2, sorted=True)

    wrongMaxLogit = tf.where(
        tf.equal(maxLabel, label_), top2[:, 1], top2[:, 0])

    return tf.reduce_mean(tf.nn.relu(1. + wrongMaxLogit - correctLogit))


def crossEntropyLoss(logits, label):
    '''
    Cross Entropy loss for MultiClass case in joint training for
    faster convergence
    '''
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                   labels=tf.stop_gradient(label)))


def mean_absolute_error(logits, label):
    '''
    Function to compute the mean absolute error.
    '''
    return tf.reduce_mean(tf.abs(tf.subtract(logits, label)))


def hardThreshold(A, s):
    '''
    Hard thresholding function on Tensor A with sparsity s
    '''
    A_ = np.copy(A)
    A_ = A_.ravel()
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    A_ = A_.reshape(A.shape)
    return A_


def copySupport(src, dest):
    '''
    copy support of src tensor to dest tensor
    '''
    support = np.nonzero(src)
    dest_ = dest
    dest = np.zeros(dest_.shape)
    dest[support] = dest_[support]
    return dest


def countnnZ(A, s, bytesPerVar=4):
    '''
    Returns # of non-zeros and representative size of the tensor
    Uses dense for s >= 0.5 - 4 byte
    Else uses sparse - 8 byte
    '''
    params = 1
    hasSparse = False
    for i in range(0, len(A.shape)):
        params *= int(A.shape[i])
    if s < 0.5:
        nnZ = np.ceil(params * s)
        hasSparse = True
        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
    else:
        nnZ = params
        return nnZ, nnZ * bytesPerVar, hasSparse


def getConfusionMatrix(predicted, target, numClasses):
    '''
    Returns a confusion matrix for a multiclass classification
    problem. `predicted` is a 1-D array of integers representing
    the predicted classes and `target` is the target classes.

    confusion[i][j]: Number of elements of class j
        predicted as class i
    Labels are assumed to be in range(0, numClasses)
    Use`printFormattedConfusionMatrix` to echo the confusion matrix
    in a user friendly form.
    '''
    assert(predicted.ndim == 1)
    assert(target.ndim == 1)
    arr = np.zeros([numClasses, numClasses])

    for i in range(len(predicted)):
        arr[predicted[i]][target[i]] += 1
    return arr


def printFormattedConfusionMatrix(matrix):
    '''
    Given a 2D confusion matrix, prints it in a human readable way.
    The confusion matrix is expected to be a 2D numpy array with
    square dimensions
    '''
    assert(matrix.ndim == 2)
    assert(matrix.shape[0] == matrix.shape[1])
    RECALL = 'Recall'
    PRECISION = 'PRECISION'
    print("|%s|" % ('True->'), end='')
    for i in range(matrix.shape[0]):
        print("%7d|" % i, end='')
    print("%s|" % 'Precision')

    print("|%s|" % ('-' * len(RECALL)), end='')
    for i in range(matrix.shape[0]):
        print("%s|" % ('-' * 7), end='')
    print("%s|" % ('-' * len(PRECISION)))

    precisionlist = np.sum(matrix, axis=1)
    recalllist = np.sum(matrix, axis=0)
    precisionlist = [matrix[i][i] / x if x !=
                     0 else -1 for i, x in enumerate(precisionlist)]
    recalllist = [matrix[i][i] / x if x !=
                  0 else -1 for i, x in enumerate(recalllist)]
    for i in range(matrix.shape[0]):
        # len recall = 6
        print("|%6d|" % (i), end='')
        for j in range(matrix.shape[0]):
            print("%7d|" % (matrix[i][j]), end='')
        print("%s" % (" " * (len(PRECISION) - 7)), end='')
        if precisionlist[i] != -1:
            print("%1.5f|" % precisionlist[i])
        else:
            print("%7s|" % "nan")

    print("|%s|" % ('-' * len(RECALL)), end='')
    for i in range(matrix.shape[0]):
        print("%s|" % ('-' * 7), end='')
    print("%s|" % ('-' * len(PRECISION)))
    print("|%s|" % ('Recall'), end='')

    for i in range(matrix.shape[0]):
        if recalllist[i] != -1:
            print("%1.5f|" % (recalllist[i]), end='')
        else:
            print("%7s|" % "nan", end='')

    print('%s|' % (' ' * len(PRECISION)))


def getPrecisionRecall(cmatrix, label=1):
    trueP = cmatrix[label][label]
    denom = np.sum(cmatrix, axis=0)[label]
    if denom == 0:
        denom = 1
    recall = trueP / denom
    denom = np.sum(cmatrix, axis=1)[label]
    if denom == 0:
        denom = 1
    precision = trueP / denom
    return precision, recall


def getMacroPrecisionRecall(cmatrix):
    # TP + FP
    precisionlist = np.sum(cmatrix, axis=1)
    # TP + FN
    recalllist = np.sum(cmatrix, axis=0)
    precisionlist__ = [cmatrix[i][i] / x if x !=
                       0 else 0 for i, x in enumerate(precisionlist)]
    recalllist__ = [cmatrix[i][i] / x if x !=
                    0 else 0 for i, x in enumerate(recalllist)]
    precision = np.sum(precisionlist__)
    precision /= len(precisionlist__)
    recall = np.sum(recalllist__)
    recall /= len(recalllist__)
    return precision, recall


def getMicroPrecisionRecall(cmatrix):
    # TP + FP
    precisionlist = np.sum(cmatrix, axis=1)
    # TP + FN
    recalllist = np.sum(cmatrix, axis=0)
    num = 0.0
    for i in range(len(cmatrix)):
        num += cmatrix[i][i]

    precision = num / np.sum(precisionlist)
    recall = num / np.sum(recalllist)
    return precision, recall


def getMacroMicroFScore(cmatrix):
    '''
    Returns macro and micro f-scores.
    Refer: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
    '''
    precisionlist = np.sum(cmatrix, axis=1)
    recalllist = np.sum(cmatrix, axis=0)
    precisionlist__ = [cmatrix[i][i] / x if x !=
                       0 else 0 for i, x in enumerate(precisionlist)]
    recalllist__ = [cmatrix[i][i] / x if x !=
                    0 else 0 for i, x in enumerate(recalllist)]
    macro = 0.0
    for i in range(len(precisionlist)):
        denom = precisionlist__[i] + recalllist__[i]
        numer = precisionlist__[i] * recalllist__[i] * 2
        if denom == 0:
            denom = 1
        macro += numer / denom
    macro /= len(precisionlist)

    num = 0.0
    for i in range(len(precisionlist)):
        num += cmatrix[i][i]

    denom1 = np.sum(precisionlist)
    denom2 = np.sum(recalllist)
    pi = num / denom1
    rho = num / denom2
    denom = pi + rho
    if denom == 0:
        denom = 1
    micro = 2 * pi * rho / denom
    return macro, micro


class GraphManager:
    '''
    Manages saving and restoring graphs. Designed to be used with EMI-RNN
    though is general enough to be useful otherwise as well.
    '''

    def __init__(self):
        pass

    def checkpointModel(self, saver, sess, modelPrefix,
                        globalStep=1000, redirFile=None):
        saver.save(sess, modelPrefix, global_step=globalStep)
        print('Model saved to %s, global_step %d' % (modelPrefix, globalStep),
              file=redirFile)

    def loadCheckpoint(self, sess, modelPrefix, globalStep,
                       redirFile=None):
        metaname = modelPrefix + '-%d.meta' % globalStep
        basename = os.path.basename(metaname)
        fileList = os.listdir(os.path.dirname(modelPrefix))
        fileList = [x for x in fileList if x.startswith(basename)]
        assert len(fileList) > 0, 'Checkpoint file not found'
        msg = 'Too many or too few checkpoint files for globalStep: %d' % globalStep
        assert len(fileList) is 1, msg
        chkpt = basename + '/' + fileList[0]
        saver = tf.train.import_meta_graph(metaname)
        metaname = metaname[:-5]
        saver.restore(sess, metaname)
        graph = tf.get_default_graph()
        return graph

# Model Trainer - ProtoNN

In [4]:
#Trainer
class ProtoNNTrainer:
    def __init__(self, protoNNObj, regW, regB, regZ,
                 sparcityW, sparcityB, sparcityZ,
                 learningRate, X, Y, lossType='l2'):
        '''
        A wrapper for the various techniques used for training ProtoNN. This
        subsumes both the responsibility of loss graph construction and
        performing training. The original training routine that is part of the
        C++ implementation of EdgeML used iterative hard thresholding (IHT),
        gamma estimation through median heuristic and other tricks for
        training ProtoNN. This module implements the same in Tensorflow
        and python.

        protoNNObj: An instance of ProtoNN class defining the forward
            computation graph. The loss functions and training routines will be
            attached to this instance.
        regW, regB, regZ: Regularization constants for W, B, and
            Z matrices of protoNN.
        sparcityW, sparcityB, sparcityZ: Sparsity constraints
            for W, B and Z matrices. A value between 0 (exclusive) and 1
            (inclusive) is expected. A value of 1 indicates dense training.
        learningRate: Initial learning rate for ADAM optimizer.
        X, Y : Placeholders for data and labels.
            X [-1, featureDimension]
            Y [-1, num Labels]
        lossType: ['l2', 'xentropy']
        '''
        self.protoNNObj = protoNNObj
        self.__regW = regW
        self.__regB = regB
        self.__regZ = regZ
        self.__sW = sparcityW
        self.__sB = sparcityB
        self.__sZ = sparcityZ
        self.__lR = learningRate
        self.X = X
        self.Y = Y
        self.sparseTraining = True
        if (sparcityW == 1.0) and (sparcityB == 1.0) and (sparcityZ == 1.0):
            self.sparseTraining = False
            print("Sparse training disabled.", file=sys.stderr)
        # Define placeholders for sparse training
        self.W_th = None
        self.B_th = None
        self.Z_th = None
        self.__lossType = lossType
        self.__validInit = False
        self.__validInit = self.__validateInit()
        self.__protoNNOut = protoNNObj(X, Y)
        self.loss = self.__lossGraph()
        self.trainStep = self.__trainGraph()
        self.__hthOp = self.__getHardThresholdOp()
        self.accuracy = protoNNObj.getAccuracyOp()

    def __validateInit(self):
        self.__validInit = False
        msg = "Sparsity value should be between"
        msg += " 0 and 1 (both inclusive)."
        assert self.__sW >= 0. and self.__sW <= 1., 'W:' + msg
        assert self.__sB >= 0. and self.__sB <= 1., 'B:' + msg
        assert self.__sZ >= 0. and self.__sZ <= 1., 'Z:' + msg
        d, dcap, m, L, _ = self.protoNNObj.getHyperParams()
        msg = 'Y should be of dimension [-1, num labels/classes]'
        msg += ' specified as part of ProtoNN object.'
        assert (len(self.Y.shape)) == 2, msg
        assert (self.Y.shape[1] == L), msg
        msg = 'X should be of dimension [-1, featureDimension]'
        msg += ' specified as part of ProtoNN object.'
        assert (len(self.X.shape) == 2), msg
        assert (self.X.shape[1] == d), msg
        self.__validInit = True
        msg = 'Values can be \'l2\', or \'xentropy\''
        if self.__lossType not in ['l2', 'xentropy']:
            raise ValueError(msg)
        return True

    def __lossGraph(self):
        pnnOut = self.__protoNNOut
        l1, l2, l3 = self.__regW, self.__regB, self.__regZ
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        if self.__lossType == 'l2':
            with tf.name_scope('protonn-l2-loss'):
                loss_0 = tf.nn.l2_loss(self.Y - pnnOut)
                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
                reg += l3 * tf.nn.l2_loss(Z)
                loss = loss_0 + reg
        elif self.__lossType == 'xentropy':
            with tf.name_scope('protonn-xentropy-loss'):
                loss_0 = tf.nn.softmax_cross_entropy_with_logits_v2(logits=pnnOut,
                                                         labels=tf.stop_gradient(self.Y))
                loss_0 = tf.reduce_mean(loss_0)
                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
                reg += l3 * tf.nn.l2_loss(Z)
                loss = loss_0 + reg
        return loss

    def __trainGraph(self):
        with tf.name_scope('protonn-gradient-adam'):
            trainStep = tf.train.AdamOptimizer(self.__lR)
            trainStep = trainStep.minimize(self.loss)
        return trainStep

    def __getHardThresholdOp(self):
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        self.W_th = tf.placeholder(tf.float32, name='W_th')
        self.B_th = tf.placeholder(tf.float32, name='B_th')
        self.Z_th = tf.placeholder(tf.float32, name='Z_th')
        with tf.name_scope('hard-threshold-assignments'):
            hard_thrsd_W = W.assign(self.W_th)
            hard_thrsd_B = B.assign(self.B_th)
            hard_thrsd_Z = Z.assign(self.Z_th)
            hard_thrsd_op = tf.group(hard_thrsd_W, hard_thrsd_B, hard_thrsd_Z)
        return hard_thrsd_op

    def train(self, batchSize, totalEpochs, sess,
              x_train, x_val, y_train, y_val, noInit=False,
              redirFile=None, printStep=10, valStep=3):
        '''
        Performs dense training of ProtoNN followed by iterative hard
        thresholding to enforce sparsity constraints.

        batchSize: Batch size per update
        totalEpochs: The number of epochs to run training for. One epoch is
            defined as one pass over the entire training data.
        sess: The Tensorflow session to use for running various graph
            operators.
        x_train, x_val, y_train, y_val: The numpy array containing train and
            validation data. x data is assumed to in of shape [-1,
            featureDimension] while y should have shape [-1, numberLabels].
        noInit: By default, all the tensors of the computation graph are
        initialized at the start of the training session. Set noInit=False to
        disable this behaviour.
        printStep: Number of batches between echoing of loss and train accuracy.
        valStep: Number of epochs between evolutions on validation set.
        '''
        d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
        assert batchSize >= 1, 'Batch size should be positive integer'
        assert totalEpochs >= 1, 'Total epochs should be positive integer'
        assert x_train.ndim == 2, 'Expected training data to be of rank 2'
        assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
        assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
        assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
        assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
        assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
        assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
        assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L

        # Numpy will throw asserts for arrays
        if sess is None:
            raise ValueError('sess must be valid Tensorflow session.')

        trainNumBatches = int(np.ceil(len(x_train) / batchSize))
        valNumBatches = int(np.ceil(len(x_val) / batchSize))
        x_train_batches = np.array_split(x_train, trainNumBatches)
        y_train_batches = np.array_split(y_train, trainNumBatches)
        x_val_batches = np.array_split(x_val, valNumBatches)
        y_val_batches = np.array_split(y_val, valNumBatches)
        if not noInit:
            sess.run(tf.global_variables_initializer())
        X, Y = self.X, self.Y
        W, B, Z, _ = self.protoNNObj.getModelMatrices()
        for epoch in range(totalEpochs):
            for i in range(len(x_train_batches)):
                batch_x = x_train_batches[i]
                batch_y = y_train_batches[i]
                feed_dict = {
                    X: batch_x,
                    Y: batch_y
                }
                sess.run(self.trainStep, feed_dict=feed_dict)
                if i % printStep == 0:
                    loss, acc = sess.run([self.loss, self.accuracy],
                                         feed_dict=feed_dict)
                    msg = "Epoch: %3d Batch: %3d" % (epoch, i)
                    msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
                    print(msg, file=redirFile)

            # Perform Hard thresholding
            if self.sparseTraining:
                W_, B_, Z_ = sess.run([W, B, Z])
                fd_thrsd = {
                    self.W_th: hardThreshold(W_, self.__sW),
                    self.B_th: hardThreshold(B_, self.__sB),
                    self.Z_th: hardThreshold(Z_, self.__sZ)
                }
                sess.run(self.__hthOp, feed_dict=fd_thrsd)

            if (epoch + 1) % valStep  == 0:
                acc = 0.0
                loss = 0.0
                for j in range(len(x_val_batches)):
                    batch_x = x_val_batches[j]
                    batch_y = y_val_batches[j]
                    feed_dict = {
                        X: batch_x,
                        Y: batch_y
                    }
                    acc_, loss_ = sess.run([self.accuracy, self.loss],
                                           feed_dict=feed_dict)
                    acc += acc_
                    loss += loss_
                acc /= len(y_val_batches)
                loss /= len(y_val_batches)
                print("Test Loss: %2.5f Accuracy: %2.5f" % (loss, acc))


# Model Graph - ProtoNN

In [5]:

class ProtoNN:
    def __init__(self, inputDimension, projectionDimension, numPrototypes,
                 numOutputLabels, gamma,
                 W = None, B = None, Z = None):
        '''
        Forward computation graph for ProtoNN.

        inputDimension: Input data dimension or feature dimension.
        projectionDimension: hyperparameter
        numPrototypes: hyperparameter
        numOutputLabels: The number of output labels or classes
        W, B, Z: Numpy matrices that can be used to initialize
            projection matrix(W), prototype matrix (B) and prototype labels
            matrix (B).
            Expected Dimensions:
                W   inputDimension (d) x projectionDimension (d_cap)
                B   projectionDimension (d_cap) x numPrototypes (m)
                Z   numOutputLabels (L) x numPrototypes (m)
        '''
        with tf.name_scope('protoNN') as ns:
            self.__nscope = ns
        self.__d = inputDimension
        self.__d_cap = projectionDimension
        self.__m = numPrototypes
        self.__L = numOutputLabels

        self.__inW = W
        self.__inB = B
        self.__inZ = Z
        self.__inGamma = gamma
        self.W, self.B, self.Z = None, None, None
        self.gamma = None

        self.__validInit = False
        self.__initWBZ()
        self.__initGamma()
        self.__validateInit()
        self.protoNNOut = None
        self.predictions = None
        self.accuracy = None

    def __validateInit(self):
        self.__validInit = False
        errmsg = "Dimensions mismatch! Should be W[d, d_cap]"
        errmsg += ", B[d_cap, m] and Z[L, m]"
        d, d_cap, m, L, _ = self.getHyperParams()
        assert self.W.shape[0] == d, errmsg
        assert self.W.shape[1] == d_cap, errmsg
        assert self.B.shape[0] == d_cap, errmsg
        assert self.B.shape[1] == m, errmsg
        assert self.Z.shape[0] == L, errmsg
        assert self.Z.shape[1] == m, errmsg
        self.__validInit = True

    def __initWBZ(self):
        with tf.name_scope(self.__nscope):
            W = self.__inW
            if W is None:
                W = tf.random_normal_initializer()
                W = W([self.__d, self.__d_cap])
            self.W = tf.Variable(W, name='W', dtype=tf.float32)

            B = self.__inB
            if B is None:
                B = tf.random_uniform_initializer()
                B = B([self.__d_cap, self.__m])
            self.B = tf.Variable(B, name='B', dtype=tf.float32)

            Z = self.__inZ
            if Z is None:
                Z = tf.random_normal_initializer()
                Z = Z([self.__L, self.__m])
            Z = tf.Variable(Z, name='Z', dtype=tf.float32)
            self.Z = Z
        return self.W, self.B, self.Z

    def __initGamma(self):
        with tf.name_scope(self.__nscope):
            gamma = self.__inGamma
            self.gamma = tf.constant(gamma, name='gamma')

    def getHyperParams(self):
        '''
        Returns the model hyperparameters:
            [inputDimension, projectionDimension,
            numPrototypes, numOutputLabels, gamma]
        '''
        d = self.__d
        dcap = self.__d_cap
        m = self.__m
        L = self.__L
        return d, dcap, m, L, self.gamma

    def getModelMatrices(self):
        '''
        Returns Tensorflow tensors of the model matrices, which
        can then be evaluated to obtain corresponding numpy arrays.

        These can then be exported as part of other implementations of
        ProtonNN, for instance a C++ implementation or pure python
        implementation.
        Returns
            [ProjectionMatrix (W), prototypeMatrix (B),
             prototypeLabelsMatrix (Z), gamma]
        '''
        return self.W, self.B, self.Z, self.gamma

    def __call__(self, X, Y=None):
        '''
        This method is responsible for construction of the forward computation
        graph. The end point of the computation graph, or in other words the
        output operator for the forward computation is returned. Additionally,
        if the argument Y is provided, a classification accuracy operator with
        Y as target will also be created. For this, Y is assumed to in one-hot
        encoded format and the class with the maximum prediction score is
        compared to the encoded class in Y.  This accuracy operator is returned
        by getAccuracyOp() method. If a different accuracyOp is required, it
        can be defined by overriding the createAccOp(protoNNScoresOut, Y)
        method.

        X: Input tensor or placeholder of shape [-1, inputDimension]
        Y: Optional tensor or placeholder for targets (labels or classes).
            Expected shape is [-1, numOutputLabels].
        returns: The forward computation outputs, self.protoNNOut
        '''
        # This should never execute
        assert self.__validInit is True, "Initialization failed!"
        if self.protoNNOut is not None:
            return self.protoNNOut

        W, B, Z, gamma = self.W, self.B, self.Z, self.gamma
        with tf.name_scope(self.__nscope):
            WX = tf.matmul(X, W)
            # Convert WX to tensor so that broadcasting can work
            dim = [-1, WX.shape.as_list()[1], 1]
            WX = tf.reshape(WX, dim)
            dim = [1, B.shape.as_list()[0], -1]
            B = tf.reshape(B, dim)
            l2sim = B - WX
            l2sim = tf.pow(l2sim, 2)
            l2sim = tf.reduce_sum(l2sim, 1, keepdims=True)
            self.l2sim = l2sim
            gammal2sim = (-1 * gamma * gamma) * l2sim
            M = tf.exp(gammal2sim)
            dim = [1] + Z.shape.as_list()
            Z = tf.reshape(Z, dim)
            y = tf.multiply(Z, M)
            y = tf.reduce_sum(y, 2, name='protoNNScoreOut')
            self.protoNNOut = y
            self.predictions = tf.argmax(y, 1, name='protoNNPredictions')
            if Y is not None:
                self.createAccOp(self.protoNNOut, Y)
        return y

    def createAccOp(self, outputs, target):
        '''
        Define an accuracy operation on ProtoNN's output scores and targets.
        Here a simple classification accuracy operator is defined. More
        complicated operators (for multiple label problems and so forth) can be
        defined by overriding this method
        '''
        assert self.predictions is not None
        target = tf.argmax(target, 1)
        correctPrediction = tf.equal(self.predictions, target)
        acc = tf.reduce_mean(tf.cast(correctPrediction, tf.float32),
                             name='protoNNAccuracy')
        self.accuracy = acc

    def getPredictionsOp(self):
        '''
        The predictions operator is defined as argmax(protoNNScores) for each
        prediction.
        '''
        return self.predictions

    def getAccuracyOp(self):
        '''
        returns accuracyOp as defined by createAccOp. It defaults to
        multi-class classification accuracy.
        '''
        msg = "Accuracy operator not defined in graph. Did you provide Y as an"
        msg += " argument to _call_?"
        assert self.accuracy is not None, msg
        return self.accuracy

# Obtain Data

It is assumed that the Daphnet data has already been downloaded,preprocessed and set up in subdirectory.

In [6]:
DATA_DIR = r"./experiments"
windowLen = 'data'
out = preprocessData(DATA_DIR,windowLen)
dataDimension = out[0]
numClasses = out[1]
x_train, y_train = out[2], out[3]
x_test, y_test = out[4], out[5]
print("Feature Dimension: ", dataDimension)
print("Num classes: ", numClasses)

Feature Dimension:  423
Num classes:  2


In [7]:
DATA_DIR = r"./experiments"
train, test = np.load(DATA_DIR + '/ttrain_data.npy'), np.load(DATA_DIR + '/ttest_data.npy')
x_train, y_train = train[:, 1:], train[:, 0]
x_test, y_test = test[:, 1:], test[:, 0]

numClasses = max(y_train) - min(y_train) + 1
numClasses = max(numClasses, max(y_test) - min(y_test) + 1)
numClasses = int(numClasses)

y_train = helper.to_onehot(y_train, numClasses)
y_test = helper.to_onehot(y_test, numClasses)
dataDimension = x_train.shape[1]
numClasses = y_train.shape[1]

# Model Parameters

Note that ProtoNN is very sensitive to the value of the hyperparameter $\gamma$, here stored in valiable `GAMMA`. If `GAMMA` is set to `None`, median heuristic will be used to estimate a good value of $\gamma$ through the `helper.getGamma()` method. This method also returns the corresponding `W` and `B` matrices which should be used to initialize ProtoNN (as is done here).

In [8]:
PROJECTION_DIM = 5 #d^
NUM_PROTOTYPES = 40 #m
REG_W = 0.000005
REG_B = 0.0
REG_Z = 0.00005
SPAR_W = 1.0
SPAR_B = 0.8
SPAR_Z = 0.8
LEARNING_RATE = 0.001
NUM_EPOCHS = 600
BATCH_SIZE = 32
GAMMA = 0.007586

In [9]:
W, B, gamma = getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)

In [10]:
X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')

In [11]:
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix,classification_report
from functools import partial

X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
def objective(trial,x_train, x_test, y_train, y_test):
    W, B, gamma = getGamma(GAMMA, PROJECTION_DIM, dataDimension,
                       NUM_PROTOTYPES, x_train)
    # Inside the optimization function, you use the 'trial' object to suggest hyperparameters
    REG_W = trial.suggest_float('REG_W', 2e-6, 5e-6)
    REG_B = trial.suggest_float('REG_B', 0.0, 0.01)
    REG_Z = trial.suggest_float('REG_Z', 2e-5, 5e-5)
    SPAR_W = trial.suggest_float('SPAR_W', 0.5, 1.0)
    SPAR_B = trial.suggest_float('SPAR_B', 0.5, 1.0)
    SPAR_Z = trial.suggest_float('SPAR_Z', 0.5, 1.0)
    LEARNING_RATE = trial.suggest_float('LEARNING_RATE', 1e-4, 1e-3)
    NUM_EPOCHS = trial.suggest_int('NUM_EPOCHS', 200, 600)

    # Set the suggested hyperparameters in the trainer
    protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)
    trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, X, Y, lossType='xentropy')
    # Call your ProtoNN trainer function or use it as needed
    sess = tf.Session()

    trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,printStep=600, valStep=10)
    acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
    pred = sess.run(protoNN.predictions, feed_dict={X: x_test, Y: y_test})
    # W, B, Z are tensorflow graph nodes
    W, B, Z, _ = protoNN.getModelMatrices()
    matrixList = sess.run([W, B, Z])
    sparcityList = [SPAR_W, SPAR_B, SPAR_Z]                       
    nnz, size, sparse = getModelSize(matrixList, sparcityList)
    y_test = np.argmax(y_test,axis=1)
    sensitivity = confusion_matrix(y_test,pred)[1][1]/(confusion_matrix(y_test,pred)[1][1] + confusion_matrix(y_test,pred)[1][0])
    specificity = confusion_matrix(y_test,pred)[0][0]/(confusion_matrix(y_test,pred)[0][0] + confusion_matrix(y_test,pred)[0][1])
    return (sensitivity+specificity)/2



In [12]:
import optuna
study = optuna.create_study(direction='maximize')


[I 2023-11-17 13:07:07,306] A new study created in memory with name: no-name-f0231465-89ff-4f90-a3ca-24098684743e


In [13]:
op_fun = partial(objective,x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test)
study.optimize(op_fun,n_trials=20)

Epoch:   0 Batch:   0 Loss: 0.28062 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.17970 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.15184 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 0.37073 Accuracy: 1.00000
Epoch:   4 Batch:   0 Loss: 0.87558 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.34357 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.64776 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.81831 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.90185 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.93356 Accuracy: 0.00000
Test Loss: 1.09768 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.93526 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.92006 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.89471 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.86478 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.83123 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.79649 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.76071 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.72452 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77258 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77259 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77261 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77264 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77268 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77272 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77278 Accuracy: 0.00000
Test Loss: 0.70668 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.77285 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.77291 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.77298 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.77307 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.77315 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.77323 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.77332 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.77342 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.77352 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.77362 Accuracy: 0.00000
Test Loss: 0.70282 Accuracy: 0.50070
Epoch: 1

[I 2023-11-17 13:09:24,795] Trial 0 finished with value: 0.5004020908725372 and parameters: {'REG_W': 3.810381964194984e-06, 'REG_B': 0.008502786220215441, 'REG_Z': 3.65253497805578e-05, 'SPAR_W': 0.8805874465018864, 'SPAR_B': 0.6587546957396175, 'SPAR_Z': 0.5053907565086231, 'LEARNING_RATE': 0.00040739436645659944, 'NUM_EPOCHS': 260}. Best is trial 0 with value: 0.5004020908725372.


Test Loss: 0.68225 Accuracy: 0.50060
Epoch:   0 Batch:   0 Loss: 1.30456 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 4.70260 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.37755 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.04379 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.93987 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.88502 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.83778 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.78877 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.73576 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.67984 Accuracy: 0.00000
Test Loss: 1.41102 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.62035 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.55769 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.49231 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.42448 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.35303 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.28109 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.20536 Accuracy: 0.00000
Epoch:  

Epoch: 142 Batch:   0 Loss: 0.75975 Accuracy: 0.00000
Epoch: 143 Batch:   0 Loss: 0.75937 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.75900 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.75866 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.75833 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.75801 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.75771 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.75743 Accuracy: 0.00000
Test Loss: 0.69773 Accuracy: 0.50060
Epoch: 150 Batch:   0 Loss: 0.75716 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.75690 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.75665 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.75640 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.75616 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.75594 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.75572 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.75552 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.75531 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.75512 Accuracy:

Epoch: 285 Batch:   0 Loss: 0.74430 Accuracy: 0.00000
Epoch: 286 Batch:   0 Loss: 0.74428 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74425 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74422 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74419 Accuracy: 0.00000
Test Loss: 0.72742 Accuracy: 0.50060
Epoch: 290 Batch:   0 Loss: 0.74416 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74413 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74410 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.74407 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.74405 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.74402 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.74399 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.74396 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.74393 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.74391 Accuracy: 0.00000
Test Loss: 0.73081 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.74388 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.74385 Accuracy: 0.00000
Epoch: 3

Epoch: 428 Batch:   0 Loss: 0.74221 Accuracy: 0.00000
Epoch: 429 Batch:   0 Loss: 0.74220 Accuracy: 0.00000
Test Loss: 0.77548 Accuracy: 0.50080
Epoch: 430 Batch:   0 Loss: 0.74219 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.74218 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.74217 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.74216 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.74214 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.74213 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.74212 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.74210 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.74209 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.74208 Accuracy: 0.00000
Test Loss: 0.77850 Accuracy: 0.50080
Epoch: 440 Batch:   0 Loss: 0.74207 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.74205 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.74204 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.74203 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.74202 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 13:14:19,416] Trial 1 finished with value: 0.5006031363088058 and parameters: {'REG_W': 2.3889116573607026e-06, 'REG_B': 0.00276641635228956, 'REG_Z': 4.095409275748267e-05, 'SPAR_W': 0.8364205006053882, 'SPAR_B': 0.9623193240893757, 'SPAR_Z': 0.7706814875793248, 'LEARNING_RATE': 0.0006869908732317194, 'NUM_EPOCHS': 538}. Best is trial 1 with value: 0.5006031363088058.


Epoch:   0 Batch:   0 Loss: 0.09184 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.45713 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 1.75579 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.23665 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.37615 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.42436 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.44473 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.45385 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.45652 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.45470 Accuracy: 0.00000
Test Loss: 1.15220 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.44929 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.44130 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.43106 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.41856 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.40382 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.38658 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.36713 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.34488 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77320 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77284 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77249 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77214 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77181 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77149 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77119 Accuracy: 0.00000
Test Loss: 0.70910 Accuracy: 0.50090
Epoch: 150 Batch:   0 Loss: 0.77088 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.77059 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.77030 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.77003 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76978 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76953 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76929 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.76907 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76885 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76864 Accuracy: 0.00000
Test Loss: 0.70457 Accuracy: 0.50090
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.74607 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74597 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74587 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74576 Accuracy: 0.00000
Test Loss: 0.68045 Accuracy: 0.50100
Epoch: 290 Batch:   0 Loss: 0.74566 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74557 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74547 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.74538 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.74528 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.74520 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.74511 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.74502 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.74494 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.74486 Accuracy: 0.00000
Test Loss: 0.68131 Accuracy: 0.50100
Epoch: 300 Batch:   0 Loss: 0.74478 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.74470 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.74463 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.73927 Accuracy: 0.00000
Test Loss: 0.69745 Accuracy: 0.50100
Epoch: 430 Batch:   0 Loss: 0.73925 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.73922 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.73919 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.73917 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.73914 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.73911 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.73909 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.73906 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.73904 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.73901 Accuracy: 0.00000
Test Loss: 0.69887 Accuracy: 0.50100
Epoch: 440 Batch:   0 Loss: 0.73899 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.73896 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.73893 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.73891 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.73888 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.73886 Accuracy: 0.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.73653 Accuracy: 0.00000
Epoch: 572 Batch:   0 Loss: 0.73652 Accuracy: 0.00000
Epoch: 573 Batch:   0 Loss: 0.73651 Accuracy: 0.00000
Epoch: 574 Batch:   0 Loss: 0.73651 Accuracy: 0.00000
Epoch: 575 Batch:   0 Loss: 0.73649 Accuracy: 0.00000
Epoch: 576 Batch:   0 Loss: 0.73648 Accuracy: 0.00000
Epoch: 577 Batch:   0 Loss: 0.73647 Accuracy: 0.00000
Epoch: 578 Batch:   0 Loss: 0.73646 Accuracy: 0.00000
Epoch: 579 Batch:   0 Loss: 0.73645 Accuracy: 0.00000
Test Loss: 0.71692 Accuracy: 0.50111
Epoch: 580 Batch:   0 Loss: 0.73644 Accuracy: 0.00000
Epoch: 581 Batch:   0 Loss: 0.73643 Accuracy: 0.00000


[I 2023-11-17 13:22:38,343] Trial 2 finished with value: 0.5009047044632087 and parameters: {'REG_W': 2.3811746750267102e-06, 'REG_B': 0.002946692532198456, 'REG_Z': 4.828542175111063e-05, 'SPAR_W': 0.953388755074015, 'SPAR_B': 0.7181145448240711, 'SPAR_Z': 0.8185762503600094, 'LEARNING_RATE': 0.00048259573711175347, 'NUM_EPOCHS': 582}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 0.20565 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 3.78336 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.03371 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.78009 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.71030 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.70523 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.71731 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.72878 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.73408 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.73241 Accuracy: 0.00000
Test Loss: 1.42221 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.72461 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.71158 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.69480 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.67436 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.65119 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.62527 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.59770 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.56775 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77140 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77117 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77094 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77071 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77048 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77025 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77005 Accuracy: 0.00000
Test Loss: 0.74199 Accuracy: 0.50060
Epoch: 150 Batch:   0 Loss: 0.76986 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76967 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76949 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76930 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76915 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76899 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76882 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.76864 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76848 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76832 Accuracy: 0.00000
Test Loss: 0.73827 Accuracy: 0.50060
Epoch: 1

[I 2023-11-17 13:25:42,449] Trial 3 finished with value: 0.5004020908725372 and parameters: {'REG_W': 2.5413170126742132e-06, 'REG_B': 0.006482346810895375, 'REG_Z': 3.838850424544486e-05, 'SPAR_W': 0.5495479131542118, 'SPAR_B': 0.7076576660375489, 'SPAR_Z': 0.56169644186905, 'LEARNING_RATE': 0.0006389569712456179, 'NUM_EPOCHS': 213}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 0.12429 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.08299 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.09090 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 0.34594 Accuracy: 1.00000
Epoch:   4 Batch:   0 Loss: 0.89466 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.37802 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.68166 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.85007 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.93602 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.97600 Accuracy: 0.00000
Test Loss: 1.11332 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.99065 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.99188 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.98578 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.97543 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.96269 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.94865 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.93381 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.91830 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.76358 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.76415 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.76470 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.76529 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.76585 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.76642 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.76700 Accuracy: 0.00000
Test Loss: 0.77201 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.76759 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76818 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76877 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76937 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76997 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.77055 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.77113 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.77170 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.77228 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.77286 Accuracy: 0.00000
Test Loss: 0.76189 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.76553 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.76538 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.76523 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.76508 Accuracy: 0.00000
Test Loss: 0.69891 Accuracy: 0.50090
Epoch: 290 Batch:   0 Loss: 0.76493 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.76478 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.76462 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.76447 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.76432 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.76417 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.76402 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.76386 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.76370 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.76354 Accuracy: 0.00000
Test Loss: 0.69733 Accuracy: 0.50090
Epoch: 300 Batch:   0 Loss: 0.76339 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.76323 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.76308 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.75056 Accuracy: 0.00000
Test Loss: 0.69560 Accuracy: 0.50080
Epoch: 430 Batch:   0 Loss: 0.75049 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.75042 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.75035 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.75028 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.75021 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.75015 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.75008 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.75001 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.74995 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.74989 Accuracy: 0.00000
Test Loss: 0.69620 Accuracy: 0.50080
Epoch: 440 Batch:   0 Loss: 0.74983 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.74977 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.74973 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.74968 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.74963 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.74958 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 13:33:55,791] Trial 4 finished with value: 0.5005026135906715 and parameters: {'REG_W': 2.106554995596043e-06, 'REG_B': 0.0035354041069576925, 'REG_Z': 4.577970493534107e-05, 'SPAR_W': 0.5637900188025731, 'SPAR_B': 0.6363861051363228, 'SPAR_Z': 0.5125862107227793, 'LEARNING_RATE': 0.00040531390559979413, 'NUM_EPOCHS': 534}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 8.82441 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 3.40938 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.32924 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.08432 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.05501 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.07401 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.09765 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.11767 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.13282 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.14410 Accuracy: 0.00000
Test Loss: 1.06641 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.15193 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.15684 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.15966 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.16079 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.16063 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.15937 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.15708 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.15389 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.80158 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.80115 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.80074 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.80036 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79999 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79963 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79929 Accuracy: 0.00000
Test Loss: 0.71568 Accuracy: 0.50080
Epoch: 150 Batch:   0 Loss: 0.79895 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79867 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79837 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79809 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79783 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79760 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79738 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79717 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79698 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79680 Accuracy: 0.00000
Test Loss: 0.70921 Accuracy: 0.50080
Epoch: 1

[I 2023-11-17 13:37:08,848] Trial 5 finished with value: 0.5006031363088058 and parameters: {'REG_W': 2.0116378258662663e-06, 'REG_B': 0.004867608452905099, 'REG_Z': 4.359075766970931e-05, 'SPAR_W': 0.69839053336034, 'SPAR_B': 0.5014467537488719, 'SPAR_Z': 0.7461619806724386, 'LEARNING_RATE': 0.0003850109092687232, 'NUM_EPOCHS': 204}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 14.61238 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 8.29134 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.19560 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.20260 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.11150 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.13666 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.16771 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.19142 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.20799 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.21890 Accuracy: 0.00000
Test Loss: 1.10526 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.22530 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.22768 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.22660 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.22253 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.21582 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.20675 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.19556 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.18244 Accuracy

Epoch: 143 Batch:   0 Loss: 0.79998 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79982 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79966 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79948 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79930 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79914 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79897 Accuracy: 0.00000
Test Loss: 0.69714 Accuracy: 0.50070
Epoch: 150 Batch:   0 Loss: 0.79880 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79865 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79850 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79836 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79824 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79812 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.79801 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.79793 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.79785 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.79777 Accuracy: 0.00000
Test Loss: 0.69255 Accuracy: 0.50070
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78323 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78266 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78209 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78154 Accuracy: 0.00000
Test Loss: 0.66690 Accuracy: 0.50090
Epoch: 290 Batch:   0 Loss: 0.78101 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.78049 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.77998 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.77949 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.77901 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.77854 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.77808 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77763 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77720 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77679 Accuracy: 0.00000
Test Loss: 0.66672 Accuracy: 0.50090
Epoch: 300 Batch:   0 Loss: 0.77638 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77599 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77562 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 13:41:41,445] Trial 6 finished with value: 0.5004020908725372 and parameters: {'REG_W': 4.121217309324934e-06, 'REG_B': 0.008181582116268727, 'REG_Z': 2.371372933028167e-05, 'SPAR_W': 0.7798664321781542, 'SPAR_B': 0.5633915948914334, 'SPAR_Z': 0.877669365611784, 'LEARNING_RATE': 0.0004288571397523209, 'NUM_EPOCHS': 342}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 0.11542 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 1.74523 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.79806 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.09930 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.18378 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.20157 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.19225 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.16788 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.13110 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.08805 Accuracy: 0.00000
Test Loss: 1.57659 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 3.03720 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.97933 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.91423 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.84168 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.76352 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.67659 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.58036 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.47720 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77252 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77189 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77131 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77075 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77023 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.76975 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.76929 Accuracy: 0.00000
Test Loss: 0.69433 Accuracy: 0.50050
Epoch: 150 Batch:   0 Loss: 0.76887 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76844 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76802 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76764 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76727 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76690 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76657 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.76624 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76592 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76564 Accuracy: 0.00000
Test Loss: 0.69451 Accuracy: 0.50050
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.75607 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.75605 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.75604 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.75602 Accuracy: 0.00000
Test Loss: 0.71916 Accuracy: 0.50070
Epoch: 290 Batch:   0 Loss: 0.75601 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.75600 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.75599 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.75597 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.75596 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.75595 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.75593 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.75592 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.75591 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.75590 Accuracy: 0.00000
Test Loss: 0.72120 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.75589 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.75588 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.75587 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 13:46:02,427] Trial 7 finished with value: 0.5005026135906715 and parameters: {'REG_W': 4.412697676135592e-06, 'REG_B': 0.003198383489368524, 'REG_Z': 3.552689934474277e-05, 'SPAR_W': 0.7978297845104647, 'SPAR_B': 0.84924161136927, 'SPAR_Z': 0.8079038266929486, 'LEARNING_RATE': 0.0008210480292785842, 'NUM_EPOCHS': 384}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 5.75914 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 2.83749 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.86660 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.92428 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.95118 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.95749 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.95039 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.93410 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.91105 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.88296 Accuracy: 0.00000
Test Loss: 1.43994 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.84909 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.81210 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.76845 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.72043 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.66657 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.60484 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.53730 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 2.46276 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77212 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77189 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77160 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77125 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77087 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77044 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.76999 Accuracy: 0.00000
Test Loss: 0.69696 Accuracy: 0.50111
Epoch: 150 Batch:   0 Loss: 0.76953 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76905 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76855 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76803 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76750 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76698 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76645 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.76591 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76539 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76488 Accuracy: 0.00000
Test Loss: 0.69199 Accuracy: 0.50111
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.74019 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74015 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74012 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74008 Accuracy: 0.00000
Test Loss: 0.71432 Accuracy: 0.50090
Epoch: 290 Batch:   0 Loss: 0.74005 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74003 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74000 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.73998 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.73995 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.73992 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.73989 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.73987 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.73984 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.73981 Accuracy: 0.00000
Test Loss: 0.71690 Accuracy: 0.50090
Epoch: 300 Batch:   0 Loss: 0.73979 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.73976 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.73974 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 13:50:17,343] Trial 8 finished with value: 0.5007036590269401 and parameters: {'REG_W': 2.2931800644736617e-06, 'REG_B': 0.006665022716907156, 'REG_Z': 4.5094551587690154e-05, 'SPAR_W': 0.7593484805210193, 'SPAR_B': 0.6159196429129663, 'SPAR_Z': 0.8862216677271006, 'LEARNING_RATE': 0.0007529488504011111, 'NUM_EPOCHS': 381}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 0.66041 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 1.17187 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.64755 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.89058 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.98747 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.01348 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.00963 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.99545 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.97919 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.96393 Accuracy: 0.00000
Test Loss: 1.02673 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.95033 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.93822 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.92728 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.91701 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.90715 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.89737 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.88757 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.87761 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.82107 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.82033 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.81960 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.81887 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.81815 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.81744 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.81673 Accuracy: 0.00000
Test Loss: 0.70508 Accuracy: 0.50050
Epoch: 150 Batch:   0 Loss: 0.81602 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.81533 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.81464 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.81399 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.81329 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.81266 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.81204 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.81143 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.81081 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.81020 Accuracy: 0.00000
Test Loss: 0.69969 Accuracy: 0.50060
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.79349 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.79337 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.79326 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.79315 Accuracy: 0.00000
Test Loss: 0.66098 Accuracy: 0.50070
Epoch: 290 Batch:   0 Loss: 0.79305 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.79296 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.79287 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.79280 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.79275 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.79271 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.79267 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.79264 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.79262 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.79262 Accuracy: 0.00000
Test Loss: 0.65957 Accuracy: 0.50070
Epoch: 300 Batch:   0 Loss: 0.79262 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.79263 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.79263 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.76232 Accuracy: 0.00000
Test Loss: 0.65583 Accuracy: 0.50060
Epoch: 430 Batch:   0 Loss: 0.76219 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.76207 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.76193 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.76180 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.76166 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.76153 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.76141 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.76128 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.76117 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.76105 Accuracy: 0.00000
Test Loss: 0.65671 Accuracy: 0.50060
Epoch: 440 Batch:   0 Loss: 0.76094 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.76083 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.76072 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.76061 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.76051 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.76040 Accuracy: 0.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.75069 Accuracy: 0.00000
Epoch: 572 Batch:   0 Loss: 0.75063 Accuracy: 0.00000
Epoch: 573 Batch:   0 Loss: 0.75056 Accuracy: 0.00000
Epoch: 574 Batch:   0 Loss: 0.75050 Accuracy: 0.00000
Epoch: 575 Batch:   0 Loss: 0.75044 Accuracy: 0.00000
Epoch: 576 Batch:   0 Loss: 0.75038 Accuracy: 0.00000
Epoch: 577 Batch:   0 Loss: 0.75032 Accuracy: 0.00000
Epoch: 578 Batch:   0 Loss: 0.75026 Accuracy: 0.00000
Epoch: 579 Batch:   0 Loss: 0.75020 Accuracy: 0.00000
Test Loss: 0.67309 Accuracy: 0.50060
Epoch: 580 Batch:   0 Loss: 0.75014 Accuracy: 0.00000


[I 2023-11-17 13:58:41,307] Trial 9 finished with value: 0.5004020908725372 and parameters: {'REG_W': 2.4810859452802976e-06, 'REG_B': 0.00988922374826537, 'REG_Z': 3.261693080385779e-05, 'SPAR_W': 0.6028177136949872, 'SPAR_B': 0.7005312263694913, 'SPAR_Z': 0.5302911886348941, 'LEARNING_RATE': 0.00031652384749609113, 'NUM_EPOCHS': 581}. Best is trial 2 with value: 0.5009047044632087.


Epoch:   0 Batch:   0 Loss: 0.00596 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.00624 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.01096 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 0.05319 Accuracy: 1.00000
Epoch:   4 Batch:   0 Loss: 0.27952 Accuracy: 1.00000
Epoch:   5 Batch:   0 Loss: 0.67344 Accuracy: 1.00000
Epoch:   6 Batch:   0 Loss: 0.91267 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.00609 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.04003 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.05301 Accuracy: 0.00000
Test Loss: 0.73942 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.05820 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.05956 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.05875 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.05629 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.05258 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.04788 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.04236 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.03613 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.28256 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.28033 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.27816 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.27602 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.27394 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.27190 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.26990 Accuracy: 1.00000
Test Loss: 0.54259 Accuracy: 0.75730
Epoch: 150 Batch:   0 Loss: 0.26795 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.26603 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.26417 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.26235 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.26056 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.25881 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.25710 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.25543 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.25378 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.25219 Accuracy: 1.00000
Test Loss: 0.54002 Accuracy: 0.75881
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.18507 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.18497 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.18481 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.18475 Accuracy: 1.00000
Test Loss: 0.53022 Accuracy: 0.77718
Epoch: 290 Batch:   0 Loss: 0.18466 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.18457 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.18448 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.18439 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.18430 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.18421 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.18413 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.18404 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.18396 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.18389 Accuracy: 1.00000
Test Loss: 0.53016 Accuracy: 0.77889
Epoch: 300 Batch:   0 Loss: 0.18380 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.18372 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.18364 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.17870 Accuracy: 1.00000
Test Loss: 0.53629 Accuracy: 0.78320
Epoch: 430 Batch:   0 Loss: 0.17867 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.17864 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.17861 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.17858 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.17855 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.17852 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.17849 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.17846 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.17843 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.17840 Accuracy: 1.00000
Test Loss: 0.53721 Accuracy: 0.78340
Epoch: 440 Batch:   0 Loss: 0.17837 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.17834 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.17830 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.17827 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.17824 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.17821 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 14:04:12,448] Trial 10 finished with value: 0.7833735424205871 and parameters: {'REG_W': 3.1651209759814418e-06, 'REG_B': 1.5029243127284535e-05, 'REG_Z': 4.9678837029664995e-05, 'SPAR_W': 0.9694619549595722, 'SPAR_B': 0.8063620879846112, 'SPAR_Z': 0.9979067852308309, 'LEARNING_RATE': 0.0001202698283373907, 'NUM_EPOCHS': 484}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 20.02424 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 18.53956 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 16.66841 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 14.76781 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 12.84827 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 10.89393 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 8.94775 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 7.02891 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 5.15114 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.38570 Accuracy: 0.00000
Test Loss: 0.99467 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.05309 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.41565 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.18329 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.11014 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.09162 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.09011 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.09277 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.09561 Acc

Epoch: 143 Batch:   0 Loss: 0.61245 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.61028 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.60814 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.60604 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.60398 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.60195 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.59996 Accuracy: 1.00000
Test Loss: 0.66309 Accuracy: 0.69732
Epoch: 150 Batch:   0 Loss: 0.59801 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.59608 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.59419 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.59232 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.59049 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.58869 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.58693 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.58519 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.58348 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.58180 Accuracy: 1.00000
Test Loss: 0.66035 Accuracy: 0.70626
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.47860 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.47817 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.47774 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.47732 Accuracy: 1.00000
Test Loss: 0.64563 Accuracy: 0.74262
Epoch: 290 Batch:   0 Loss: 0.47689 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.47647 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.47606 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.47564 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.47523 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.47478 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.47433 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.47387 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.47343 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.47298 Accuracy: 1.00000
Test Loss: 0.64480 Accuracy: 0.74413
Epoch: 300 Batch:   0 Loss: 0.47254 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.47211 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.47168 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.43412 Accuracy: 1.00000
Test Loss: 0.63596 Accuracy: 0.76081
Epoch: 430 Batch:   0 Loss: 0.43389 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.43367 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.43344 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.43322 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.43299 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.43277 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.43255 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.43233 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.43211 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.43189 Accuracy: 1.00000
Test Loss: 0.63534 Accuracy: 0.76212
Epoch: 440 Batch:   0 Loss: 0.43167 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.43145 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.43123 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.43101 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.43079 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.43057 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 14:09:29,082] Trial 11 finished with value: 0.764173703256936 and parameters: {'REG_W': 3.2395899592085873e-06, 'REG_B': 0.0003664057333160444, 'REG_Z': 4.973956832128323e-05, 'SPAR_W': 0.9837845128293876, 'SPAR_B': 0.7837171297182737, 'SPAR_Z': 0.988068899370609, 'LEARNING_RATE': 0.0001012807558511427, 'NUM_EPOCHS': 473}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 4.40007 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 3.03392 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.88292 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.35881 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.18530 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.13539 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.12351 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.12198 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.12246 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.12269 Accuracy: 0.00000
Test Loss: 0.74729 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.12206 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.12062 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.11843 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.11544 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.11167 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.10702 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.10166 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.09554 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.54054 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.53949 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.53846 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.53745 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.53645 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.53548 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.53452 Accuracy: 1.00000
Test Loss: 0.64857 Accuracy: 0.73269
Epoch: 150 Batch:   0 Loss: 0.53358 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.53265 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.53174 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.53085 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.52997 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.52910 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.52825 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.52741 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.52659 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.52579 Accuracy: 1.00000
Test Loss: 0.64810 Accuracy: 0.73389
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.47358 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.47335 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.47311 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.47288 Accuracy: 1.00000
Test Loss: 0.64336 Accuracy: 0.75267
Epoch: 290 Batch:   0 Loss: 0.47265 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.47243 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.47220 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.47198 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.47175 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.47153 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.47131 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.47109 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.47087 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.47065 Accuracy: 1.00000
Test Loss: 0.64293 Accuracy: 0.75318
Epoch: 300 Batch:   0 Loss: 0.47040 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.47016 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.46992 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.44705 Accuracy: 1.00000
Test Loss: 0.63666 Accuracy: 0.76322
Epoch: 430 Batch:   0 Loss: 0.44690 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.44676 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.44661 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.44646 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.44631 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.44616 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.44601 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.44586 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.44571 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.44556 Accuracy: 1.00000
Test Loss: 0.63624 Accuracy: 0.76372
Epoch: 440 Batch:   0 Loss: 0.44542 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.44527 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.44512 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.44497 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.44483 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.44468 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 14:14:51,984] Trial 12 finished with value: 0.7646763168476076 and parameters: {'REG_W': 3.220491817791798e-06, 'REG_B': 0.0005350984518690573, 'REG_Z': 4.898567836971721e-05, 'SPAR_W': 0.9871943753190696, 'SPAR_B': 0.7865101139326912, 'SPAR_Z': 0.9854272311827109, 'LEARNING_RATE': 0.00010566103743412826, 'NUM_EPOCHS': 451}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 1.96295 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 1.43166 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.17168 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.09669 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.07652 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.07033 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.06668 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.06263 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.05761 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.05168 Accuracy: 0.00000
Test Loss: 0.72899 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.04509 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.03776 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.02999 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.02203 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.01411 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.00595 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 0.99780 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 0.98988 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.21845 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.21662 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.21484 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.21310 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.21140 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.20974 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.20811 Accuracy: 1.00000
Test Loss: 0.53152 Accuracy: 0.75043
Epoch: 150 Batch:   0 Loss: 0.20652 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.20495 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.20342 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.20192 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.20045 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.19902 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.19761 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.19623 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.19488 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.19356 Accuracy: 1.00000
Test Loss: 0.53026 Accuracy: 0.75244
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.13394 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.13384 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.13375 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.13365 Accuracy: 1.00000
Test Loss: 0.52572 Accuracy: 0.77193
Epoch: 290 Batch:   0 Loss: 0.13356 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.13347 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.13338 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.13329 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.13321 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.13312 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.13303 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.13296 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.13288 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.13281 Accuracy: 1.00000
Test Loss: 0.52549 Accuracy: 0.77203
Epoch: 300 Batch:   0 Loss: 0.13274 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.13267 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.13260 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.12971 Accuracy: 1.00000
Test Loss: 0.52681 Accuracy: 0.78298
Epoch: 430 Batch:   0 Loss: 0.12971 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.12970 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.12970 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.12969 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.12969 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.12968 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.12967 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.12967 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.12967 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.12966 Accuracy: 1.00000
Test Loss: 0.52723 Accuracy: 0.78358
Epoch: 440 Batch:   0 Loss: 0.12965 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.12964 Accuracy: 1.00000


[I 2023-11-17 14:19:25,821] Trial 13 finished with value: 0.783373542420587 and parameters: {'REG_W': 3.2463815718835868e-06, 'REG_B': 3.572113501268689e-06, 'REG_Z': 4.9983371351560104e-05, 'SPAR_W': 0.9243469318758182, 'SPAR_B': 0.8076493738250382, 'SPAR_Z': 0.99383634206201, 'LEARNING_RATE': 0.0001038928417897229, 'NUM_EPOCHS': 442}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 3.46925 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 1.85211 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.61304 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.62114 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.63809 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.64711 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.64986 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.64857 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.64366 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.63586 Accuracy: 0.00000
Test Loss: 0.86037 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.62592 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.61415 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.60080 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.58615 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.57044 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.55370 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.53636 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.51819 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.50519 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.50412 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.50303 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.50197 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.50094 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.49993 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.49895 Accuracy: 1.00000
Test Loss: 0.70160 Accuracy: 0.72687
Epoch: 150 Batch:   0 Loss: 0.49798 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.49704 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.49610 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.49519 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.49429 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.49340 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.49253 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.49167 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.49082 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.48998 Accuracy: 1.00000
Test Loss: 0.69905 Accuracy: 0.72878
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.42543 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.42507 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.42470 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.42434 Accuracy: 1.00000
Test Loss: 0.67255 Accuracy: 0.75470
Epoch: 290 Batch:   0 Loss: 0.42398 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.42362 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.42326 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.42290 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.42255 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.42220 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.42185 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.42150 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.42114 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.42079 Accuracy: 1.00000
Test Loss: 0.67112 Accuracy: 0.75560
Epoch: 300 Batch:   0 Loss: 0.42045 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.42010 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.41975 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.38098 Accuracy: 1.00000
Test Loss: 0.65956 Accuracy: 0.76393
Epoch: 430 Batch:   0 Loss: 0.38071 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.38044 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.38016 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.37989 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.37962 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.37935 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.37908 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.37881 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.37854 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.37827 Accuracy: 1.00000
Test Loss: 0.65900 Accuracy: 0.76424
Epoch: 440 Batch:   0 Loss: 0.37801 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.37773 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.37746 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.37720 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.37693 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.37667 Accuracy: 1.00000
Epoch: 4

[I 2023-11-17 14:24:30,245] Trial 14 finished with value: 0.7647768395657419 and parameters: {'REG_W': 4.9443223762717384e-06, 'REG_B': 0.00023673683599163402, 'REG_Z': 4.9885120979036455e-05, 'SPAR_W': 0.8947625074756174, 'SPAR_B': 0.8462338809988931, 'SPAR_Z': 0.9883925840878015, 'LEARNING_RATE': 0.00023019529772352665, 'NUM_EPOCHS': 458}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 5.33966 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 2.49974 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.62745 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.51148 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.51585 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.52937 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.53747 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.54032 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.53911 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.53498 Accuracy: 0.00000
Test Loss: 0.88718 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.52861 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.52052 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.51116 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.50097 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.49036 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.47944 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.46821 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.45714 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.79314 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.79274 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.79237 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.79205 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.79174 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.79148 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.79121 Accuracy: 0.00000
Test Loss: 0.77117 Accuracy: 0.50040
Epoch: 150 Batch:   0 Loss: 0.79100 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.79081 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.79063 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.79044 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.79027 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.79008 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.78989 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.78970 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.78950 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.78928 Accuracy: 0.00000
Test Loss: 0.75255 Accuracy: 0.50050
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.78083 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.78072 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.78058 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.78041 Accuracy: 0.00000
Test Loss: 0.65984 Accuracy: 0.50060
Epoch: 290 Batch:   0 Loss: 0.78021 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.77997 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.77971 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.77943 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.77912 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.77878 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.77843 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.77806 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.77768 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.77728 Accuracy: 0.00000
Test Loss: 0.65792 Accuracy: 0.50060
Epoch: 300 Batch:   0 Loss: 0.77688 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.77647 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.77605 Accuracy: 0.00000
Epoch: 3

[I 2023-11-17 14:27:54,413] Trial 15 finished with value: 0.5005026135906715 and parameters: {'REG_W': 2.9817944061172906e-06, 'REG_B': 0.0015764948468564553, 'REG_Z': 4.2455140492222395e-05, 'SPAR_W': 0.9216345594073406, 'SPAR_B': 0.8740899737578733, 'SPAR_Z': 0.9224633223709945, 'LEARNING_RATE': 0.0002329813134211586, 'NUM_EPOCHS': 311}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 4.09905 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 3.22776 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 3.46697 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 3.50745 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 3.49307 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 3.45398 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 3.39392 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 3.31812 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 3.22441 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 3.11006 Accuracy: 0.00000
Test Loss: 1.72716 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 2.97928 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.83617 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.68179 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.51929 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.35450 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.18927 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 2.02767 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.87293 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77098 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77022 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.76940 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.76863 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.76792 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.76725 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.76652 Accuracy: 0.00000
Test Loss: 0.73979 Accuracy: 0.50050
Epoch: 150 Batch:   0 Loss: 0.76567 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.76475 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.76380 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.76288 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.76199 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.76117 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.76043 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.75976 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.75915 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.75859 Accuracy: 0.00000
Test Loss: 0.73837 Accuracy: 0.50050
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.75139 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.75138 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.75138 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.75138 Accuracy: 0.00000
Test Loss: 0.77610 Accuracy: 0.50080
Epoch: 290 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Test Loss: 0.77909 Accuracy: 0.50080
Epoch: 300 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.75137 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.75221 Accuracy: 0.00000
Test Loss: 0.80848 Accuracy: 0.50090
Epoch: 430 Batch:   0 Loss: 0.75222 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.75223 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.75223 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.75224 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.75225 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.75226 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.75227 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.75227 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.75228 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.75229 Accuracy: 0.00000
Test Loss: 0.80999 Accuracy: 0.50090
Epoch: 440 Batch:   0 Loss: 0.75230 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.75231 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.75232 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.75233 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.75235 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.75236 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 14:34:31,309] Trial 16 finished with value: 0.5007036590269401 and parameters: {'REG_W': 3.563427719034845e-06, 'REG_B': 0.001817478468648327, 'REG_Z': 4.590243215292866e-05, 'SPAR_W': 0.864984424398386, 'SPAR_B': 0.770681142610384, 'SPAR_Z': 0.9182472389817082, 'LEARNING_RATE': 0.0009987435328399474, 'NUM_EPOCHS': 500}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 0.01813 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.22860 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 1.05089 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.37349 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.44603 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.46450 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.46861 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.46582 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.45860 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.44818 Accuracy: 0.00000
Test Loss: 0.82225 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.43534 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.42049 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.40457 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.38785 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.37064 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.35345 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.33600 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.31862 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.34817 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.34740 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.34666 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.34593 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.34523 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.34452 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.34383 Accuracy: 1.00000
Test Loss: 0.61546 Accuracy: 0.75500
Epoch: 150 Batch:   0 Loss: 0.34316 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.34249 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.34184 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.34120 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.34058 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.33996 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.33936 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.33878 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.33820 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.33764 Accuracy: 1.00000
Test Loss: 0.61275 Accuracy: 0.75681
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.29923 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.29901 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.29879 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.29856 Accuracy: 1.00000
Test Loss: 0.59643 Accuracy: 0.77367
Epoch: 290 Batch:   0 Loss: 0.29834 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.29812 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.29789 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.29766 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.29743 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.29721 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.29699 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.29676 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.29653 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.29631 Accuracy: 1.00000
Test Loss: 0.59641 Accuracy: 0.77478
Epoch: 300 Batch:   0 Loss: 0.29608 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.29585 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.29562 Accuracy: 1.00000
Epoch: 3

[I 2023-11-17 14:41:08,471] Trial 17 finished with value: 0.7770406111781263 and parameters: {'REG_W': 2.987979066384924e-06, 'REG_B': 7.777710356943696e-05, 'REG_Z': 4.161294467229353e-05, 'SPAR_W': 0.9368568088782723, 'SPAR_B': 0.9440713822183731, 'SPAR_Z': 0.9447121506984012, 'LEARNING_RATE': 0.0002062157637362075, 'NUM_EPOCHS': 426}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 0.50034 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 2.36514 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 2.70076 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 2.68985 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 2.64494 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 2.60123 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 2.55941 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 2.51721 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 2.47312 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 2.42613 Accuracy: 0.00000
Test Loss: 1.26310 Accuracy: 0.50030
Epoch:  10 Batch:   0 Loss: 2.37555 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 2.32113 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 2.26289 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 2.20084 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 2.13524 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 2.06662 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.99528 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.92181 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.77382 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.77366 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.77345 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.77320 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.77292 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.77261 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.77229 Accuracy: 0.00000
Test Loss: 0.70409 Accuracy: 0.50100
Epoch: 150 Batch:   0 Loss: 0.77197 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.77166 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.77135 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.77106 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.77080 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.77055 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.77032 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.77012 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.76992 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.76973 Accuracy: 0.00000
Test Loss: 0.70054 Accuracy: 0.50100
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.74500 Accuracy: 0.00000
Epoch: 287 Batch:   0 Loss: 0.74494 Accuracy: 0.00000
Epoch: 288 Batch:   0 Loss: 0.74488 Accuracy: 0.00000
Epoch: 289 Batch:   0 Loss: 0.74482 Accuracy: 0.00000
Test Loss: 0.70765 Accuracy: 0.50080
Epoch: 290 Batch:   0 Loss: 0.74476 Accuracy: 0.00000
Epoch: 291 Batch:   0 Loss: 0.74470 Accuracy: 0.00000
Epoch: 292 Batch:   0 Loss: 0.74464 Accuracy: 0.00000
Epoch: 293 Batch:   0 Loss: 0.74458 Accuracy: 0.00000
Epoch: 294 Batch:   0 Loss: 0.74452 Accuracy: 0.00000
Epoch: 295 Batch:   0 Loss: 0.74446 Accuracy: 0.00000
Epoch: 296 Batch:   0 Loss: 0.74441 Accuracy: 0.00000
Epoch: 297 Batch:   0 Loss: 0.74435 Accuracy: 0.00000
Epoch: 298 Batch:   0 Loss: 0.74429 Accuracy: 0.00000
Epoch: 299 Batch:   0 Loss: 0.74424 Accuracy: 0.00000
Test Loss: 0.70953 Accuracy: 0.50080
Epoch: 300 Batch:   0 Loss: 0.74419 Accuracy: 0.00000
Epoch: 301 Batch:   0 Loss: 0.74413 Accuracy: 0.00000
Epoch: 302 Batch:   0 Loss: 0.74408 Accuracy: 0.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.73985 Accuracy: 0.00000
Test Loss: 0.72978 Accuracy: 0.50100
Epoch: 430 Batch:   0 Loss: 0.73983 Accuracy: 0.00000
Epoch: 431 Batch:   0 Loss: 0.73981 Accuracy: 0.00000
Epoch: 432 Batch:   0 Loss: 0.73979 Accuracy: 0.00000
Epoch: 433 Batch:   0 Loss: 0.73977 Accuracy: 0.00000
Epoch: 434 Batch:   0 Loss: 0.73976 Accuracy: 0.00000
Epoch: 435 Batch:   0 Loss: 0.73974 Accuracy: 0.00000
Epoch: 436 Batch:   0 Loss: 0.73972 Accuracy: 0.00000
Epoch: 437 Batch:   0 Loss: 0.73970 Accuracy: 0.00000
Epoch: 438 Batch:   0 Loss: 0.73968 Accuracy: 0.00000
Epoch: 439 Batch:   0 Loss: 0.73967 Accuracy: 0.00000
Test Loss: 0.73106 Accuracy: 0.50100
Epoch: 440 Batch:   0 Loss: 0.73965 Accuracy: 0.00000
Epoch: 441 Batch:   0 Loss: 0.73963 Accuracy: 0.00000
Epoch: 442 Batch:   0 Loss: 0.73962 Accuracy: 0.00000
Epoch: 443 Batch:   0 Loss: 0.73960 Accuracy: 0.00000
Epoch: 444 Batch:   0 Loss: 0.73958 Accuracy: 0.00000
Epoch: 445 Batch:   0 Loss: 0.73956 Accuracy: 0.00000
Epoch: 4

[I 2023-11-17 14:46:03,644] Trial 18 finished with value: 0.5008041817450744 and parameters: {'REG_W': 2.7986508078728687e-06, 'REG_B': 0.0016520481634988342, 'REG_Z': 4.701659115784846e-05, 'SPAR_W': 0.9952675904405462, 'SPAR_B': 0.899331900603328, 'SPAR_Z': 0.8610259893007911, 'LEARNING_RATE': 0.0005469267034550651, 'NUM_EPOCHS': 505}. Best is trial 10 with value: 0.7833735424205871.


Epoch:   0 Batch:   0 Loss: 0.05067 Accuracy: 1.00000
Epoch:   1 Batch:   0 Loss: 0.04223 Accuracy: 1.00000
Epoch:   2 Batch:   0 Loss: 0.36415 Accuracy: 1.00000
Epoch:   3 Batch:   0 Loss: 1.36094 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.59576 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.64890 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.67048 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.67943 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.68064 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.67594 Accuracy: 0.00000
Test Loss: 0.95305 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.66648 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.65380 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.63879 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.62171 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.60405 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.58575 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.56746 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.54984 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.85687 Accuracy: 0.00000
Epoch: 144 Batch:   0 Loss: 0.85624 Accuracy: 0.00000
Epoch: 145 Batch:   0 Loss: 0.85563 Accuracy: 0.00000
Epoch: 146 Batch:   0 Loss: 0.85503 Accuracy: 0.00000
Epoch: 147 Batch:   0 Loss: 0.85443 Accuracy: 0.00000
Epoch: 148 Batch:   0 Loss: 0.85386 Accuracy: 0.00000
Epoch: 149 Batch:   0 Loss: 0.85329 Accuracy: 0.00000
Test Loss: 0.86910 Accuracy: 0.52814
Epoch: 150 Batch:   0 Loss: 0.85274 Accuracy: 0.00000
Epoch: 151 Batch:   0 Loss: 0.85220 Accuracy: 0.00000
Epoch: 152 Batch:   0 Loss: 0.85167 Accuracy: 0.00000
Epoch: 153 Batch:   0 Loss: 0.85115 Accuracy: 0.00000
Epoch: 154 Batch:   0 Loss: 0.85065 Accuracy: 0.00000
Epoch: 155 Batch:   0 Loss: 0.85017 Accuracy: 0.00000
Epoch: 156 Batch:   0 Loss: 0.84969 Accuracy: 0.00000
Epoch: 157 Batch:   0 Loss: 0.84922 Accuracy: 0.00000
Epoch: 158 Batch:   0 Loss: 0.84876 Accuracy: 0.00000
Epoch: 159 Batch:   0 Loss: 0.84831 Accuracy: 0.00000
Test Loss: 0.86741 Accuracy: 0.52863
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.81405 Accuracy: 0.18750
Epoch: 287 Batch:   0 Loss: 0.81376 Accuracy: 0.18750
Epoch: 288 Batch:   0 Loss: 0.81348 Accuracy: 0.18750
Epoch: 289 Batch:   0 Loss: 0.81319 Accuracy: 0.21875
Test Loss: 0.86072 Accuracy: 0.55124
Epoch: 290 Batch:   0 Loss: 0.81291 Accuracy: 0.21875
Epoch: 291 Batch:   0 Loss: 0.81262 Accuracy: 0.21875
Epoch: 292 Batch:   0 Loss: 0.81233 Accuracy: 0.21875
Epoch: 293 Batch:   0 Loss: 0.81204 Accuracy: 0.25000
Epoch: 294 Batch:   0 Loss: 0.81175 Accuracy: 0.25000
Epoch: 295 Batch:   0 Loss: 0.81146 Accuracy: 0.25000
Epoch: 296 Batch:   0 Loss: 0.81116 Accuracy: 0.25000
Epoch: 297 Batch:   0 Loss: 0.81087 Accuracy: 0.25000
Epoch: 298 Batch:   0 Loss: 0.81057 Accuracy: 0.28125
Epoch: 299 Batch:   0 Loss: 0.81028 Accuracy: 0.28125
Test Loss: 0.85991 Accuracy: 0.55445
Epoch: 300 Batch:   0 Loss: 0.80998 Accuracy: 0.28125
Epoch: 301 Batch:   0 Loss: 0.80968 Accuracy: 0.28125
Epoch: 302 Batch:   0 Loss: 0.80939 Accuracy: 0.28125
Epoch: 3

[I 2023-11-17 14:51:23,448] Trial 19 finished with value: 0.5906714917571371 and parameters: {'REG_W': 3.4256544426719675e-06, 'REG_B': 0.0012195362651619677, 'REG_Z': 4.640666946458155e-05, 'SPAR_W': 0.9249592321637052, 'SPAR_B': 0.8062760865390182, 'SPAR_Z': 0.9531998601449847, 'LEARNING_RATE': 0.0002936820967518912, 'NUM_EPOCHS': 412}. Best is trial 10 with value: 0.7833735424205871.


In [14]:
study.best_trials

[FrozenTrial(number=10, state=TrialState.COMPLETE, values=[0.7833735424205871], datetime_start=datetime.datetime(2023, 11, 17, 13, 58, 41, 309742), datetime_complete=datetime.datetime(2023, 11, 17, 14, 4, 12, 447046), params={'REG_W': 3.1651209759814418e-06, 'REG_B': 1.5029243127284535e-05, 'REG_Z': 4.9678837029664995e-05, 'SPAR_W': 0.9694619549595722, 'SPAR_B': 0.8063620879846112, 'SPAR_Z': 0.9979067852308309, 'LEARNING_RATE': 0.0001202698283373907, 'NUM_EPOCHS': 484}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'REG_W': FloatDistribution(high=5e-06, log=False, low=2e-06, step=None), 'REG_B': FloatDistribution(high=0.01, log=False, low=0.0, step=None), 'REG_Z': FloatDistribution(high=5e-05, log=False, low=2e-05, step=None), 'SPAR_W': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'SPAR_B': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'SPAR_Z': FloatDistribution(high=1.0, log=False, low=0.5, step=None), 'LEARNING_RATE': FloatDistri

In [15]:
PROJECTION_DIM = 5 #d^
NUM_PROTOTYPES = 40 #m
REG_W = 3.1545272147130644e-06
REG_B = 0.00025738817748430687
REG_Z = 3.534809578054365e-05
SPAR_W = 0.6622823978261199
SPAR_B = 0.651918182305612
SPAR_Z = 0.6197401480486052
LEARNING_RATE = 0.00029791501581002825
NUM_EPOCHS = 595
BATCH_SIZE = 32
GAMMA = 0.007586

# Model Training

In [16]:

# Setup input and train protoNN


X = tf.placeholder(tf.float32, [None, dataDimension], name='X')
Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')
protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
                  NUM_PROTOTYPES, numClasses,
                  gamma, W=W, B=B)
trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
                         SPAR_W, SPAR_B, SPAR_Z,
                         LEARNING_RATE, X, Y, lossType='xentropy')
sess = tf.Session()

trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,
              printStep=600, valStep=10)


Epoch:   0 Batch:   0 Loss: 1.96731 Accuracy: 0.00000
Epoch:   1 Batch:   0 Loss: 2.25109 Accuracy: 0.00000
Epoch:   2 Batch:   0 Loss: 1.97812 Accuracy: 0.00000
Epoch:   3 Batch:   0 Loss: 1.88018 Accuracy: 0.00000
Epoch:   4 Batch:   0 Loss: 1.84649 Accuracy: 0.00000
Epoch:   5 Batch:   0 Loss: 1.83512 Accuracy: 0.00000
Epoch:   6 Batch:   0 Loss: 1.83033 Accuracy: 0.00000
Epoch:   7 Batch:   0 Loss: 1.82631 Accuracy: 0.00000
Epoch:   8 Batch:   0 Loss: 1.82104 Accuracy: 0.00000
Epoch:   9 Batch:   0 Loss: 1.81396 Accuracy: 0.00000
Test Loss: 0.96558 Accuracy: 0.50020
Epoch:  10 Batch:   0 Loss: 1.80509 Accuracy: 0.00000
Epoch:  11 Batch:   0 Loss: 1.79451 Accuracy: 0.00000
Epoch:  12 Batch:   0 Loss: 1.78257 Accuracy: 0.00000
Epoch:  13 Batch:   0 Loss: 1.76965 Accuracy: 0.00000
Epoch:  14 Batch:   0 Loss: 1.75536 Accuracy: 0.00000
Epoch:  15 Batch:   0 Loss: 1.74069 Accuracy: 0.00000
Epoch:  16 Batch:   0 Loss: 1.72541 Accuracy: 0.00000
Epoch:  17 Batch:   0 Loss: 1.70923 Accuracy:

Epoch: 143 Batch:   0 Loss: 0.59260 Accuracy: 1.00000
Epoch: 144 Batch:   0 Loss: 0.59133 Accuracy: 1.00000
Epoch: 145 Batch:   0 Loss: 0.59007 Accuracy: 1.00000
Epoch: 146 Batch:   0 Loss: 0.58889 Accuracy: 1.00000
Epoch: 147 Batch:   0 Loss: 0.58771 Accuracy: 1.00000
Epoch: 148 Batch:   0 Loss: 0.58655 Accuracy: 1.00000
Epoch: 149 Batch:   0 Loss: 0.58539 Accuracy: 1.00000
Test Loss: 0.76412 Accuracy: 0.67794
Epoch: 150 Batch:   0 Loss: 0.58414 Accuracy: 1.00000
Epoch: 151 Batch:   0 Loss: 0.58292 Accuracy: 1.00000
Epoch: 152 Batch:   0 Loss: 0.58173 Accuracy: 1.00000
Epoch: 153 Batch:   0 Loss: 0.58055 Accuracy: 1.00000
Epoch: 154 Batch:   0 Loss: 0.57940 Accuracy: 1.00000
Epoch: 155 Batch:   0 Loss: 0.57824 Accuracy: 1.00000
Epoch: 156 Batch:   0 Loss: 0.57713 Accuracy: 1.00000
Epoch: 157 Batch:   0 Loss: 0.57601 Accuracy: 1.00000
Epoch: 158 Batch:   0 Loss: 0.57492 Accuracy: 1.00000
Epoch: 159 Batch:   0 Loss: 0.57382 Accuracy: 1.00000
Test Loss: 0.76141 Accuracy: 0.68317
Epoch: 1

Epoch: 286 Batch:   0 Loss: 0.48397 Accuracy: 1.00000
Epoch: 287 Batch:   0 Loss: 0.48346 Accuracy: 1.00000
Epoch: 288 Batch:   0 Loss: 0.48295 Accuracy: 1.00000
Epoch: 289 Batch:   0 Loss: 0.48245 Accuracy: 1.00000
Test Loss: 0.73289 Accuracy: 0.72416
Epoch: 290 Batch:   0 Loss: 0.48194 Accuracy: 1.00000
Epoch: 291 Batch:   0 Loss: 0.48146 Accuracy: 1.00000
Epoch: 292 Batch:   0 Loss: 0.48093 Accuracy: 1.00000
Epoch: 293 Batch:   0 Loss: 0.48042 Accuracy: 1.00000
Epoch: 294 Batch:   0 Loss: 0.47992 Accuracy: 1.00000
Epoch: 295 Batch:   0 Loss: 0.47937 Accuracy: 1.00000
Epoch: 296 Batch:   0 Loss: 0.47885 Accuracy: 1.00000
Epoch: 297 Batch:   0 Loss: 0.47834 Accuracy: 1.00000
Epoch: 298 Batch:   0 Loss: 0.47782 Accuracy: 1.00000
Epoch: 299 Batch:   0 Loss: 0.47732 Accuracy: 1.00000
Test Loss: 0.73106 Accuracy: 0.72586
Epoch: 300 Batch:   0 Loss: 0.47683 Accuracy: 1.00000
Epoch: 301 Batch:   0 Loss: 0.47636 Accuracy: 1.00000
Epoch: 302 Batch:   0 Loss: 0.47588 Accuracy: 1.00000
Epoch: 3

Epoch: 429 Batch:   0 Loss: 0.42628 Accuracy: 1.00000
Test Loss: 0.71115 Accuracy: 0.74384
Epoch: 430 Batch:   0 Loss: 0.42595 Accuracy: 1.00000
Epoch: 431 Batch:   0 Loss: 0.42562 Accuracy: 1.00000
Epoch: 432 Batch:   0 Loss: 0.42528 Accuracy: 1.00000
Epoch: 433 Batch:   0 Loss: 0.42495 Accuracy: 1.00000
Epoch: 434 Batch:   0 Loss: 0.42465 Accuracy: 1.00000
Epoch: 435 Batch:   0 Loss: 0.42431 Accuracy: 1.00000
Epoch: 436 Batch:   0 Loss: 0.42396 Accuracy: 1.00000
Epoch: 437 Batch:   0 Loss: 0.42363 Accuracy: 1.00000
Epoch: 438 Batch:   0 Loss: 0.42330 Accuracy: 1.00000
Epoch: 439 Batch:   0 Loss: 0.42296 Accuracy: 1.00000
Test Loss: 0.70998 Accuracy: 0.74505
Epoch: 440 Batch:   0 Loss: 0.42264 Accuracy: 1.00000
Epoch: 441 Batch:   0 Loss: 0.42231 Accuracy: 1.00000
Epoch: 442 Batch:   0 Loss: 0.42198 Accuracy: 1.00000
Epoch: 443 Batch:   0 Loss: 0.42165 Accuracy: 1.00000
Epoch: 444 Batch:   0 Loss: 0.42133 Accuracy: 1.00000
Epoch: 445 Batch:   0 Loss: 0.42100 Accuracy: 1.00000
Epoch: 4

Epoch: 571 Batch:   0 Loss: 0.38455 Accuracy: 1.00000
Epoch: 572 Batch:   0 Loss: 0.38429 Accuracy: 1.00000
Epoch: 573 Batch:   0 Loss: 0.38402 Accuracy: 1.00000
Epoch: 574 Batch:   0 Loss: 0.38377 Accuracy: 1.00000
Epoch: 575 Batch:   0 Loss: 0.38351 Accuracy: 1.00000
Epoch: 576 Batch:   0 Loss: 0.38325 Accuracy: 1.00000
Epoch: 577 Batch:   0 Loss: 0.38299 Accuracy: 1.00000
Epoch: 578 Batch:   0 Loss: 0.38274 Accuracy: 1.00000
Epoch: 579 Batch:   0 Loss: 0.38248 Accuracy: 1.00000
Test Loss: 0.69762 Accuracy: 0.75399
Epoch: 580 Batch:   0 Loss: 0.38223 Accuracy: 1.00000
Epoch: 581 Batch:   0 Loss: 0.38197 Accuracy: 1.00000
Epoch: 582 Batch:   0 Loss: 0.38172 Accuracy: 1.00000
Epoch: 583 Batch:   0 Loss: 0.38147 Accuracy: 1.00000
Epoch: 584 Batch:   0 Loss: 0.38121 Accuracy: 1.00000
Epoch: 585 Batch:   0 Loss: 0.38096 Accuracy: 1.00000
Epoch: 586 Batch:   0 Loss: 0.38071 Accuracy: 1.00000
Epoch: 587 Batch:   0 Loss: 0.38046 Accuracy: 1.00000
Epoch: 588 Batch:   0 Loss: 0.38020 Accuracy:

# Model Evaluation

In [17]:
acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
pred = sess.run(protoNN.predictions, feed_dict={X: x_test, Y: y_test})
# W, B, Z are tensorflow graph nodes
W, B, Z, _ = protoNN.getModelMatrices()
matrixList = sess.run([W, B, Z])
sparcityList = [SPAR_W, SPAR_B, SPAR_Z]                       
nnz, size, sparse = getModelSize(matrixList, sparcityList)
print("Final test accuracy", acc)
print("Model size constraint (Bytes): ", size)
print("Number of non-zeros: ", nnz)

Final test accuracy 0.7552272
Model size constraint (Bytes):  9580
Number of non-zeros:  2395


In [18]:
from sklearn.metrics import confusion_matrix,classification_report
y_test = np.argmax(y_test,axis=1)
print (confusion_matrix(y_test,pred))
print (classification_report(y_test,pred,digits=5))

[[2946 2028]
 [ 407 4567]]
              precision    recall  f1-score   support

           0    0.87862   0.59228   0.70758      4974
           1    0.69249   0.91817   0.78952      4974

    accuracy                        0.75523      9948
   macro avg    0.78556   0.75523   0.74855      9948
weighted avg    0.78556   0.75523   0.74855      9948



In [19]:
sensitivity = confusion_matrix(y_test,pred)[1][1]/(confusion_matrix(y_test,pred)[1][1] + confusion_matrix(y_test,pred)[1][0])
sensitivity

0.9181745074386811

In [20]:
specificity = confusion_matrix(y_test,pred)[0][0]/(confusion_matrix(y_test,pred)[0][0] + confusion_matrix(y_test,pred)[0][1])
specificity

0.5922798552472859

In [21]:
# accuracy = (TP + TN) / (TP + TN + FP + FN) 
# Precision = TP / (TP + FP)
# Recall = TP / (TP + FN)
# sensitivity = TP / (TP + FN) 
# specificity = TN / (TN + FP) 
# positive predictive value (PPV) = TP / (TP + FP)