In [41]:
from typing import Optional, Union, Tuple, List, Sequence, Iterable

import os
import torch
import numpy as np
from time import time as t
import h5py

from sklearn.model_selection import ParameterGrid
from bindsnet.network.monitors import Monitor
from bindsnet.network.nodes import Input, LIFNodes, DiehlAndCookNodes, AdaptiveLIFNodes
from bindsnet.network import Network
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import Input, LIFNodes, AdaptiveLIFNodes

from bindsnet.learning import PostPre
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.analysis.plotting import (
    plot_input,
    plot_assignments,
    plot_performance,
    plot_weights,
    plot_spikes,
    plot_voltages,
)

In [42]:
n_folds = 0
test_size = 0.2
seed = 0
n_neurons = 100
n_clamp = 1
exc = 2.5
inh = 22.5
time = 255
n_dim = 64
dt = 1.0
progress_interval = 10
update_interval = 25
train = True
plot = True
gpu = False
n_class = 2
lr_pre = 1
lr_post = 1
nu = [lr_pre, lr_post]
norm = 0.1
theta_plus = 0.05,
tc_theta_decay = 1e7
n_epochs = 1
percentage_of_test_set=1
if gpu:
    torch.set_default_tensor_type("torch.cuda.FloatTensor")
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)


from bindsnet.network import load
from bindsnet.learning import NoOp

from torch.utils.data import Dataset
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit

# In[4]:

In [43]:
train_size = 0.8
test_size = 0.2
n_folds = 4
n_epochs = 1


# In[5]:

In [44]:
class EEGAlcoDatasetBalanced(Dataset):
    """EEG Train dataset."""

    def __init__(self, testingByDatapoint, returnObservation=None, observationsSize=None):
        """
        Args:
            none.
        """
        self.testingByDatapoint = testingByDatapoint
        self.returnObservation = returnObservation
        self.observationsSize = observationsSize
        h5f = h5py.File('alco_scalars_balanced_X_train.h5','r')
        self.spikes_seizure_eeg_train = h5f['dataset_alco_scalars_balanced_X_train'][:]
        #self.spikes_seizure_eeg_train=np.swapaxes(self.spikes_seizure_eeg_train,1,2)
        h5f.close()
        
        h5f = h5py.File('alco_scalars_balanced_X_test.h5','r')
        self.spikes_seizure_eeg_test = h5f['dataset_alco_scalars_balanced_X_test'][:]
        #self.spikes_seizure_eeg_test=np.swapaxes(self.spikes_seizure_eeg_test,1,2)

        #print(str(np.sum(self.labels_seizure_eeg))+'/'+str(len(self.labels_seizure_eeg)))
        h5f.close()
        
        h5f = h5py.File('alco_balanced_y_train.h5','r')
        self.labels_seizure_eeg_train = h5f['dataset_alco_balanced_y_train'][:]
        #self.spikes_seizure_eeg=np.swapaxes(self.spikes_seizure_eeg,1,2)
        h5f.close()
        
        h5f = h5py.File('alco_balanced_y_test.h5','r')
        self.labels_seizure_eeg_test = h5f['dataset_alco_balanced_y_test'][:]
        #self.spikes_seizure_eeg=np.swapaxes(self.spikes_seizure_eeg,1,2)
        h5f.close()

    def get_data(self):
        #all folds
        dataArray = list()

        trainLabels=self.labels_seizure_eeg_train
        trainValues=self.spikes_seizure_eeg_train  
        testLabels=self.labels_seizure_eeg_test
        testValues=self.spikes_seizure_eeg_test

        shuffle = np.random.RandomState(seed=0).permutation(len(trainValues))
        trainValues = trainValues[shuffle]
        trainLabels = trainLabels[shuffle]
        if (self.testingByDatapoint == True):
            trainValues = trainValues[self.returnObservation * self.observationsSize:
                                      self.returnObservation * self.observationsSize + self.observationsSize]
            # trainValues=trainValues.reshape((1,trainValues.shape[0],trainValues.shape[1]))
            #print('trainLabels: ' + str(trainLabels))
            trainLabels = trainLabels[self.returnObservation * self.observationsSize:
                                      self.returnObservation * self.observationsSize + self.observationsSize]
            #print('trainLabels: ' + str(trainLabels))


        currentSplit = {'X_train': torch.tensor(trainValues), 'X_test': torch.tensor(testValues), 
                        'y_train': torch.tensor(trainLabels), 'y_test': torch.tensor(testLabels)}
        dataArray.append(currentSplit)
        return dataArray

    
    def __len__(self):
        return len(self.spikes_seizure_eeg_train)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        eeg = torch.tensor(self.spikes_seizure_eeg_train[idx])
        print('eeg size (in getitem): '+str(eeg.size()))
        label = self.labels_seizure_eeg_train[idx]
            
        sample = {'eeg': eeg, 'label': label}
        return sample

In [45]:
# In[6]:


dataset = EEGAlcoDatasetBalanced(testingByDatapoint=False)
# Create a dataloader to iterate and batch data
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu)

dataArray = dataset.get_data()
n_train = dataArray[0]['X_train'].shape[0]
n_test = dataArray[0]['X_test'].shape[0]

# In[ ]:


# In[8]:


model = 'diehl_and_cook_2015'
data = 'alco_snn'

data_path = os.path.join('data', 'seizure')
params_path = os.path.join('params', data, model)
spikes_path = os.path.join('spikes', data, model)
curves_path = os.path.join('curves', data, model)
results_path = os.path.join('results', data, model)
confusion_path = os.path.join('confusion', data, model)

for path in [params_path, spikes_path, curves_path, results_path, confusion_path]:
    if not os.path.isdir(path):
        os.makedirs(path)

In [46]:
# In[9]:


class NewNetwork(Network):
    # language=rst
    """
    Implements the spiking neural network architecture from `(Diehl & Cook 2015)
    <https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full>`_.
    """

    def __init__(
            self,
            n_inpt: int,
            n_neurons: int = 100,
            exc: float = 22.5,
            inh: float = 17.5,
            dt: float = 1.0,
            nu: Optional[Union[float, Sequence[float]]] = (1e-4, 1e-2),
            reduction: Optional[callable] = None,
            wmin: float = 0.0,
            wmax: float = 1.0,
            norm: float = 78.4,
            theta_plus: float = 0.05,
            tc_theta_decay: float = 1e7,
            inpt_shape: Optional[Iterable[int]] = None,
    ) -> None:
        # language=rst
        """
        Constructor for class ``DiehlAndCook2015``.
        :param n_inpt: Number of input neurons. Matches the 1D size of the input data.
        :param n_neurons: Number of excitatory, inhibitory neurons.
        :param exc: Strength of synapse weights from excitatory to inhibitory layer.
        :param inh: Strength of synapse weights from inhibitory to excitatory layer.
        :param dt: Simulation time step.
        :param nu: Single or pair of learning rates for pre- and post-synaptic events,
            respectively.
        :param reduction: Method for reducing parameter updates along the minibatch
            dimension.
        :param wmin: Minimum allowed weight on input to excitatory synapses.
        :param wmax: Maximum allowed weight on input to excitatory synapses.
        :param norm: Input to excitatory layer connection weights normalization
            constant.
        :param theta_plus: On-spike increment of ``DiehlAndCookNodes`` membrane
            threshold potential.
        :param tc_theta_decay: Time constant of ``DiehlAndCookNodes`` threshold
            potential decay.
        :param inpt_shape: The dimensionality of the input layer.
        """
        super().__init__(dt=dt)

        self.n_inpt = n_inpt
        self.inpt_shape = inpt_shape
        self.n_neurons = n_neurons
        self.exc = exc
        self.inh = inh
        self.dt = dt

        # Layers
        input_layer = Input(
            n=self.n_inpt, shape=self.inpt_shape, traces=True, tc_trace=20.0
        )
        exc_layer = DiehlAndCookNodes(
            n=self.n_neurons,
            traces=True,
            rest=-65.0,
            reset=-60.0,
            thresh=-52.0,
            refrac=5,
            tc_decay=100.0,
            tc_trace=20.0,
            theta_plus=theta_plus,
            tc_theta_decay=tc_theta_decay,
        )
        inh_layer = LIFNodes(
            n=self.n_neurons,
            traces=False,
            rest=-60.0,
            reset=-45.0,
            thresh=-40.0,
            tc_decay=10.0,
            refrac=2,
            tc_trace=20.0,
        )

        # Connections
        w = 0.3 * torch.rand(self.n_inpt, self.n_neurons)
        input_exc_conn = Connection(
            source=input_layer,
            target=exc_layer,
            w=w,
            update_rule=PostPre,
            nu=nu,
            reduction=reduction,
            wmin=wmin,
            wmax=wmax,
            norm=norm,
        )
        w = self.exc * torch.diag(torch.ones(self.n_neurons))
        exc_inh_conn = Connection(
            source=exc_layer, target=inh_layer, w=w, wmin=0, wmax=self.exc
        )
        w = -self.inh * (
                torch.ones(self.n_neurons, self.n_neurons)
                - torch.diag(torch.ones(self.n_neurons))
        )
        inh_exc_conn = Connection(
            source=inh_layer, target=exc_layer, w=w, wmin=-self.inh, wmax=0
        )

        # Add to network
        self.add_layer(input_layer, name="X")
        self.add_layer(exc_layer, name="Ae")
        self.add_layer(inh_layer, name="Ai")
        self.add_connection(input_exc_conn, source="X", target="Ae")
        self.add_connection(exc_inh_conn, source="Ae", target="Ai")
        self.add_connection(inh_exc_conn, source="Ai", target="Ae")


# In[29]:

In [66]:
def createSNN(seed=0, singleSample=False, n_neurons=n_neurons, n_train=n_train, n_test=n_test, inh=inh, exc=exc,
              lr_pre=lr_pre,
              lr_post=lr_post, time=time, dt=dt, norm=norm, intensity=30, progress_interval=progress_interval,
              update_interval=update_interval, plot=True, train=True, gpu=False, current_fold=0, current_epoch=0):
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    per_class = int(n_neurons / n_class)

    params = [
        seed, n_neurons, inh, exc, lr_pre, lr_post, time, dt, norm,
        intensity, progress_interval, current_fold
    ]
    print('params: ' + str(params))

    model_name = '_'.join([str(x) for x in params])

    dataset = EEGAlcoDatasetBalanced(testingByDatapoint=singleSample, returnObservation=current_epoch, observationsSize=n_train)
    # Create a dataloader to iterate and batch data


    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu)
    dataArray = dataset.get_data()

    totalTrainingTimeSecs = list()
    # nu: Single or pair of learning rates for pre- and post-synaptic events, respectively.
    if (train == True):
        if (current_epoch == 0):
            print('starting training (new network): ')

            network = NewNetwork(
                n_inpt=n_dim,
                n_neurons=n_neurons,
                exc=exc,
                inh=inh,
                dt=dt,
                norm=norm,
                nu=[lr_pre, lr_post],
                inpt_shape=(1, n_dim))
            print('NewNetwork: ' + str(network))
        else:
            print('continuing training (loading network): ')
            print('loading model from: ' + str(os.path.join(params_path, model_name + '.pt')))

            network = load(os.path.join(params_path, model_name + '.pt'), learning=True)

    else:
        print('starting testing: ')
        if os.path.exists(os.path.join(params_path, model_name + '.pt')):
            print('loading model from: ' + str(os.path.join(params_path, model_name + '.pt')))
            network = load(os.path.join(params_path, model_name + '.pt'), learning=False)
            network.connections['X', 'Ae'].update_rule = NoOp(
                connection=network.connections['X', 'Ae'], nu=network.connections['X', 'Ae'].nu
            )
            network.layers['Ae'].tc_theta_decay = torch.Tensor([0])
            network.layers['Ae'].theta_plus = torch.Tensor([0])

            network.connections['Ae', 'Ai'].update_rule = NoOp(
                connection=network.connections['Ae', 'Ai'], nu=network.connections['Ae', 'Ai'].nu
            )
            network.layers['Ai'].tc_theta_decay = torch.Tensor([0])
            network.layers['Ai'].theta_plus = torch.Tensor([0])

            network.connections['Ai', 'Ae'].update_rule = NoOp(
                connection=network.connections['Ai', 'Ae'], nu=network.connections['Ai', 'Ae'].nu
            )
        else:
            network = NewNetwork(
                n_inpt=n_dim,
                n_neurons=n_neurons,
                exc=exc,
                inh=inh,
                dt=dt,
                norm=norm,
                nu=[lr_pre, lr_post],
                inpt_shape=(1, n_dim))
            print('NewNetwork: ' + str(network))
    # network.to("cuda")

    if (train == False):
        update_interval = n_test

    # Voltage recording for excitatory and inhibitory layers.
    exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
    inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
    network.add_monitor(exc_voltage_monitor, name="exc_voltage")
    network.add_monitor(inh_voltage_monitor, name="inh_voltage")

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    assignments = -torch.ones_like(torch.Tensor(n_neurons))
    proportions = torch.zeros_like(torch.Tensor(n_neurons, n_class))
    rates = torch.zeros_like(torch.Tensor(n_neurons, n_class))

    # Neuron assignments and spike proportions.
    if (train == True):
        if (current_epoch == 0):
            assignments = -torch.ones_like(torch.Tensor(n_neurons))
            proportions = torch.zeros_like(torch.Tensor(n_neurons, n_class))
            rates = torch.zeros_like(torch.Tensor(n_neurons, n_class))
        else:
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            assignments, proportions, rates = torch.load(open(path, 'rb'))
    else:
        if os.path.exists(os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')):
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            assignments, proportions, rates = torch.load(open(path, 'rb'))
        else:
            assignments = -torch.ones_like(torch.Tensor(n_neurons))
            proportions = torch.zeros_like(torch.Tensor(n_neurons, n_class))
            rates = torch.zeros_like(torch.Tensor(n_neurons, n_class))

    # Sequence of accuracy estimates.
    accuracy = {"all": [], "proportion": []}

    batchTime =0
    spikes = {}
    for layer in set(network.layers) - {"X"}:
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    for epoch in np.arange(n_epochs):
        print('epoch: ' + str(epoch))

        if train:
            images = dataArray[current_fold]['X_train']
            images *= intensity
            labels = dataArray[current_fold]['y_train']
            print('\nBegin training.\n')
            n_examples = n_train
            print('n_examples: ' + str(n_examples))
        else:
            images = dataArray[current_fold]['X_test']
            images *= intensity
            labels = dataArray[current_fold]['y_test']
            print('\nBegin testing.\n')
            n_examples = n_test
            print('n_examples: ' + str(n_examples))

        all_dataset_predictions_all = np.zeros(shape=(n_examples))
        all_dataset_predictions_prop = np.zeros(shape=(n_examples))

        for i in range(n_examples):
            startOfBatchTime = t()
            print('n_examples: ' + str(n_examples))
            print('i: ' + str(i))
            print('update_interval: ' + str(update_interval))
            if train:
                print("Train progress: (%d / %d)" % (i, n_examples))
            else:
                print("Test progress: (%d / %d)" % (i, n_examples))
            print('len(images): ' + str(len(images)))
            print('(images).shape: ' + str((images).shape))

            image = images[i % len(images)]
            print('(labels): ' + str(labels))
            
            label = labels[i % len(labels)]
        
            # print('Current (i, dataPoint): '+str((i, dataPoint)))
            if (train == True):
                if i > n_train:
                    break
            else:
                if i > n_test:
                    break
            # image = dataPoint["eeg"]
            # label = dataPoint["label"]
            print('current label: ' + str(label))

            # Optionally plot various simulation information.

            # Run the network on the input.
            choice = np.random.choice(int(n_neurons / n_class), size=n_clamp, replace=False)
            clamp = {"Ae": per_class * label.long() + torch.Tensor(choice).long()}
            inputs = {"X": image.view(time, n_dim)}

            #print('inputs: '+str(inputs))
            network.run(inputs=inputs, time=time, clamp=clamp)

            # Get voltage recording.
            exc_voltages = exc_voltage_monitor.get("v")
            inh_voltages = inh_voltage_monitor.get("v")

            voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

            #print('spikes["Ae"]: ' + str(spikes["Ae"]))
            spike_record[i % update_interval] = spikes["Ae"].get("s").view(time, n_neurons)

            # print('all_labels_in_batch before network reset '+str(all_labels_in_batch))

            network.reset_state_variables()  # Reset state variables.

            if ((i) % update_interval == 0 and train==True) or ((i+1) % (update_interval) == 0 and train==False ):
                if (train):
                    update_labels = label
                else:
                    update_labels = labels
                #print('(i) % update_interval: ' + str(i))
                # Get network predictions.
                all_activity_pred = all_activity(spike_record, assignments, n_class)
                #print('predictions: ' + str(all_activity_pred))
                #print('labels: ' + str(update_labels.long()))
                proportion_pred = proportion_weighting(
                    spike_record, assignments, proportions, n_class
                )

                # Compute network accuracy according to available classification strategies.
                accuracy["all"].append(
                    100 * torch.sum(update_labels.long() == all_activity_pred.cpu()).item() / update_interval
                )
                accuracy["proportion"].append(
                    100 * torch.sum(update_labels.long() == proportion_pred.cpu()).item() / update_interval
                )

                print(
                    "\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
                    % (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
                )
                print(
                    "Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
                    % (
                        accuracy["proportion"][-1],
                        np.mean(accuracy["proportion"]),
                        np.max(accuracy["proportion"]),
                    )
                )

                # Assign labels to excitatory layer neurons.
                #print('updating assignments: ')
                #print('spike_record: ' + str(spike_record))
                #print('rates: ' + str(rates))

                assignments, proportions, rates = assign_labels(spike_record, label.view(1), 2, rates, 1)
            # print('network layers: '+str(network.layers))
            if(train==True):
                batchTime=(t() - startOfBatchTime)
                totalTrainingTimeSecs.append(batchTime)


        # print("Progress: %d / %d \n" % (n_train, n_train))
        print("Training complete.\n")

    # Save network to disk.
    if train:
        print('saving model to: ' + str(os.path.join(params_path, model_name + '.pt')))
        network.save(os.path.join(params_path, model_name + '.pt'))
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        torch.save((assignments, proportions, rates), open(path, 'wb'))

    if(train==False):
        # Compute confusion matrices .
        confusions = {}
        confusions['prop'] = confusion_matrix(labels.cpu(), proportion_pred)
        confusions['all'] = confusion_matrix(labels.cpu(), all_activity_pred)
        print('all confusion_matrix(labels, all_activity_pred): ' + str(
            confusion_matrix(labels, all_activity_pred)))
        print('prop confusion_matrix(labels, proportion_pred): ' + str(
            confusion_matrix(labels, proportion_pred)))
        return confusions, totalTrainingTimeSecs
    else:
        return 0, totalTrainingTimeSecs

In [67]:
param_grid = {'n_neurons': [5, 10], 'lr_pre': [0.1, 1], 'lr_post': [0.1, 1], 'exc': [5, 25, 125],
              'inh': [5, 25, 125], 'update_interval': [1], 'n_epochs': [1]}
param_grid = {'exc': [5], 'inh': [5], 'lr_post': [0.1], 'lr_pre': [0.1], 'n_epochs': [1], 
              'n_neurons': [5], 'update_interval': [1]}
#Opt Params
grid = ParameterGrid(param_grid)

In [68]:
for i, params in enumerate(grid):
    print(str(i)+': '+str(params))

0: {'exc': 5, 'inh': 5, 'lr_post': 0.1, 'lr_pre': 0.1, 'n_epochs': 1, 'n_neurons': 5, 'update_interval': 1}


In [60]:
#FINAL MODEL (EVALUATE AFTER EVERY TRAINING SAMPLE)

n_train_full=dataset.get_data()[0]['X_train'].shape[0]
n_test_full=dataset.get_data()[0]['X_test'].shape[0]
size_of_train_set = 1
train_set_update_interval = 1
resultsFilename = str(data)+"_results_data_OPTIMAL.txt"
if os.path.isfile(resultsFilename):
    print ("Results file exists")
else:
    print ("Results file doesn't exist, creating new file...")
starting_point=0
for i, params in enumerate(grid):
    if(i>=starting_point):
        overallPrecisionListAll = list()
        overallRecallListAll = list()
        overallAccuracyListAll = list()
        overallF1ListAll = list()
        overallTPListAll = list()
        overallTNListAll = list()
        overallFPListAll = list()
        overallFNListAll = list()
        overallConvergenceEpochAll = list()

        overallPrecisionListProp = list()
        overallRecallListProp = list()
        overallAccuracyListProp = list()
        overallF1ListProp = list()
        overallTPListProp = list()
        overallTNListProp = list()
        overallFPListProp = list()
        overallFNListProp = list()
        overallConvergenceEpochProp = list()


        trainingTimeList = list()
        with open(resultsFilename, "a") as text_file:
            print(f"\nResults for : {str(params)}\n\n", file=text_file)
        confusionMatricesList = list()
        for fold in np.arange(n_folds):
            currentFold = fold
            # Train the network.
            with open(resultsFilename, "a") as text_file:
                print(f"\nFold : {str(currentFold)}\n", file=text_file)
                print(f"Size of Total Training Set : {str(n_train_full)}\n", file=text_file)
                print(f"Eval every {str(size_of_train_set)} observations.\n", file=text_file)
            print("Begin training for fold " + str(currentFold) + "\n")
            start = t()

            foldEpochAccuracyAll = list()
            foldEpochPrecisionAll = list()
            foldEpochRecallAll = list()
            foldEpochF1scoreAll = list()
            foldEpochTPAll = list()
            foldEpochTNAll = list()
            foldEpochFPAll = list()
            foldEpochFNAll = list()

            foldEpochAccuracyProp = list()
            foldEpochPrecisionProp = list()
            foldEpochRecallProp = list()
            foldEpochF1scoreProp = list()
            foldEpochTPProp = list()
            foldEpochTNProp = list()
            foldEpochFPProp = list()
            foldEpochFNProp = list()

            for index in range(int(n_train_full/size_of_train_set)):
                if(index>=8):
                    break
                with open(resultsFilename, "a") as text_file:
                    print(f"Current Training Index : {str(index)}/{str(int(n_train_full/size_of_train_set))}\n", file=text_file)
                dataset = EEGAlcoDatasetBalanced(testingByDatapoint=True, returnObservation=index, observationsSize=size_of_train_set)
                n_train=dataset.get_data()[0]['X_train'].shape[0]
                print('n_train: '+str(n_train))
                n_test=dataset.get_data()[0]['X_test'].shape[0]

                # Create a dataloader to iterate and batch data
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu)
                result, trainingTime = createSNN(train=True, singleSample=True, n_neurons=params['n_neurons'],
                                                 lr_pre=params['lr_pre'],
                                                 inh=params['inh'], lr_post=params['lr_post'], exc=params['exc'],
                                                 update_interval=train_set_update_interval, current_fold=currentFold,
                                                 current_epoch=index, n_train=size_of_train_set)
                print("Training complete for index: " + str(index) + "\n")
                testResult, testingTime = createSNN(train=False, singleSample=False, n_neurons=params['n_neurons'],
                                                    lr_pre=params['lr_pre'],
                                                    inh=params['inh'], lr_post=params['lr_post'], exc=params['exc'],
                                                    update_interval=params['update_interval'], current_fold=currentFold,
                                                    current_epoch=index)

                TPAll = testResult['all'][0][0]
                FPAll = testResult['all'][0][1]
                FNAll = testResult['all'][1][0]
                TNAll = testResult['all'][1][1]
                TPProp = testResult['prop'][0][0]
                FPProp = testResult['prop'][0][1]
                FNProp = testResult['prop'][1][0]
                TNProp = testResult['prop'][1][1]

                precisionAll = (TPAll / (TPAll + FPAll))
                recallAll = (TPAll / (TPAll + FNAll))
                accuracyAll = (TPAll + TNAll) / (TPAll + FPAll + FNAll + TNAll)
                f1scoreAll = 2 * ((precisionAll * recallAll) / (precisionAll + recallAll))
                precisionProp = (TPProp / (TPProp + FPProp))
                recallProp = (TPProp / (TPProp + FNProp))
                accuracyProp = (TPProp + TNProp) / (TPProp + FPProp + FNProp + TNProp)
                f1scoreProp = 2 * ((precisionProp * recallProp) / (precisionProp + recallProp))

                foldEpochAccuracyAll.append(accuracyAll)
                foldEpochPrecisionAll.append(precisionAll)
                foldEpochRecallAll.append(recallAll)
                foldEpochF1scoreAll.append(f1scoreAll)
                foldEpochTPAll.append(TPAll)
                foldEpochTNAll.append(TNAll)
                foldEpochFPAll.append(FPAll)
                foldEpochFNAll.append(FNAll)

                foldEpochAccuracyProp.append(accuracyProp)
                foldEpochPrecisionProp.append(precisionProp)
                foldEpochRecallProp.append(recallProp)
                foldEpochF1scoreProp.append(f1scoreProp)
                foldEpochTPProp.append(TPProp)
                foldEpochTNProp.append(TNProp)
                foldEpochFPProp.append(FPProp)
                foldEpochFNProp.append(FNProp)

            convergenceObservationAccuracyAll=str(foldEpochAccuracyAll.index(max(foldEpochAccuracyAll)))
            convergenceObservationAccuracyProp=str(foldEpochAccuracyProp.index(max(foldEpochAccuracyProp)))

            with open(resultsFilename, "a") as text_file:
                print(f"Training complete for fold {str(currentFold)}\n", file=text_file)
                print(f"Number of Epochs: {str(len(foldEpochAccuracyAll))}\n", file=text_file)
                print(f"Fold {str(currentFold)} Test Metrics:\n", file=text_file)
                print(f"All Spikes:\n", file=text_file)
                print(f"TP: {str(foldEpochTPAll)}\nTN: {str(foldEpochTNAll)}\nFP: {str(foldEpochFPAll)}\nFN: {str(foldEpochFNAll)}\n", file=text_file)
                print(f"Accuracy: {str(foldEpochAccuracyAll)}\nPrecision: {str(foldEpochPrecisionAll)}\nRecall: {str(foldEpochRecallAll)}\nF1: {str(foldEpochF1scoreAll)}\n", file=text_file)
                print(f"Accuracy Maxes at Epoch: {str(convergenceObservationAccuracyAll)}\n", file=text_file)
                print(f"Prop Spikes:\n", file=text_file)
                print(f"TP: {str(foldEpochTPProp)}\nTN: {str(foldEpochTNProp)}\nFP: {str(foldEpochFPProp)}\nFN: {str(foldEpochFNProp)}\n",
                      file=text_file)
                print(f"Accuracy: {str(foldEpochAccuracyProp)}\nPrecision: {str(foldEpochPrecisionProp)}\nRecall: {str(foldEpochRecallProp)}\nF1: {str(foldEpochF1scoreProp)}\n", file=text_file)
                print(f"Accuracy Maxes at Epoch: {convergenceObservationAccuracyProp}\n", file=text_file)

            overallPrecisionListAll.append(foldEpochPrecisionAll)
            overallRecallListAll.append(foldEpochRecallAll)
            overallAccuracyListAll.append(foldEpochAccuracyAll)
            overallF1ListAll.append(foldEpochF1scoreAll)
            overallTPListAll.append(foldEpochTPAll)
            overallTNListAll.append(foldEpochTNAll)
            overallFPListAll.append(foldEpochFPAll)
            overallFNListAll.append(foldEpochFNAll)
            overallConvergenceEpochAll.append(convergenceObservationAccuracyAll)

            overallPrecisionListProp.append(foldEpochPrecisionProp)
            overallRecallListProp.append(foldEpochRecallProp)
            overallAccuracyListProp.append(foldEpochAccuracyProp)
            overallF1ListProp.append(foldEpochF1scoreProp)
            overallTPListProp.append(foldEpochTPProp)
            overallTNListProp.append(foldEpochTNProp)
            overallFPListProp.append(foldEpochFPProp)
            overallFNListProp.append(foldEpochFNProp)
            overallConvergenceEpochAll.append(convergenceObservationAccuracyProp)

        precisionAllMeanOverFolds=np.average(np.array(overallPrecisionListAll), axis=0)
        recallAllMeanOverFolds=np.average(np.array(overallRecallListAll), axis=0)
        accuracyAllMeanOverFolds=np.average(np.array(overallAccuracyListAll), axis=0)
        F1AllMeanOverFolds=np.average(np.array(overallF1ListAll), axis=0)
        TPAllMeanOverFolds = np.average(np.array(overallTPListAll), axis=0)
        TNAllMeanOverFolds = np.average(np.array(overallTNListAll), axis=0)
        FPAllMeanOverFolds = np.average(np.array(overallFPListAll), axis=0)
        FNAllMeanOverFolds = np.average(np.array(overallFNListAll), axis=0)
        convergenceAllMeanOverFolds = np.mean(np.array(overallConvergenceEpochAll).astype(np.float))+1

        precisionPropMeanOverFolds = np.average(np.array(overallPrecisionListProp), axis=0)
        recallPropMeanOverFolds = np.average(np.array(overallRecallListProp), axis=0)
        accuracyPropMeanOverFolds = np.average(np.array(overallAccuracyListProp), axis=0)
        F1PropMeanOverFolds = np.average(np.array(overallF1ListProp), axis=0)
        TPPropMeanOverFolds = np.average(np.array(overallTPListProp), axis=0)
        TNPropMeanOverFolds = np.average(np.array(overallTNListProp), axis=0)
        FPPropMeanOverFolds = np.average(np.array(overallFPListProp), axis=0)
        FNPropMeanOverFolds = np.average(np.array(overallFNListProp), axis=0)
        convergencePropMeanOverFolds = np.mean(np.array(overallConvergenceEpochProp).astype(np.float))+1

        with open(resultsFilename, "a") as text_file:
            print(f"Training complete for all folds for params: {str(params)}", file=text_file)
            print(f"Total Training Time: {str(trainingTime)}", file=text_file)
            
            print(f"Mean Test Metrics Over All Folds:\n", file=text_file)
            print(f"All Spikes:\n", file=text_file)
            print(f"Final Accuracy: {str(accuracyAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Precision: {str(precisionAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Recall: {str(recallAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final F1: {str(F1AllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TP: {str(TPAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FP: {str(FPAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TN: {str(TNAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FN: {str(FNAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Obs. Until Max Accuracy: {str(convergenceAllMeanOverFolds)}\n", file=text_file)
            print(
                f"TP: {str(list(TPAllMeanOverFolds))}\nTN: {str(list(TNAllMeanOverFolds))}\nFP: {str(list(FPAllMeanOverFolds))}\nFN: {str(list(FNAllMeanOverFolds))}\n",
                file=text_file)
            print(
                f"Accuracy: {str(list(accuracyAllMeanOverFolds))}\nPrecision: {str(list(precisionAllMeanOverFolds))}\nRecall: {str(list(recallAllMeanOverFolds))}\nF1: {str(list(F1AllMeanOverFolds))}\n",
                file=text_file)

            print(f"Prop Spikes:\n", file=text_file)
            print(f"Final Accuracy: {str(accuracyPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Precision: {str(precisionPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Recall: {str(recallPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final F1: {str(F1PropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TP: {str(TPPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FP: {str(FPPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TN: {str(TNPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FN: {str(FNPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Obs. Until Max Accuracy: {str(convergencePropMeanOverFolds)}\n", file=text_file)
            print(
                f"TP: {str(list(TPPropMeanOverFolds))}\nTN: {str(list(TNPropMeanOverFolds))}\nFP: {str(list(FPPropMeanOverFolds))}\nFN: {str(list(FNPropMeanOverFolds))}\n",
                file=text_file)
            print(
                f"Accuracy: {str(list(accuracyPropMeanOverFolds))}\nPrecision: {str(list(precisionPropMeanOverFolds))}\nRecall: {str(list(recallPropMeanOverFolds))}\nF1: {str(list(F1PropMeanOverFolds))}\n",
                file=text_file)

Results file exists
Begin training for fold 0

n_train: 1
params: [0, 5, 5, 5, 0.1, 0.1, 255, 1.0, 0.1, 30, 10, 0]
starting training (new network): 
NewNetwork: NewNetwork(
  (X): Input()
  (Ae): DiehlAndCookNodes()
  (Ai): LIFNodes()
  (X_to_Ae): Connection(
    (source): Input()
    (target): DiehlAndCookNodes()
  )
  (Ae_to_Ai): Connection(
    (source): DiehlAndCookNodes()
    (target): LIFNodes()
  )
  (Ai_to_Ae): Connection(
    (source): LIFNodes()
    (target): DiehlAndCookNodes()
  )
)
epoch: 0

Begin training.

n_examples: 1
n_examples: 1
i: 0
update_interval: 1
Train progress: (0 / 1)
len(images): 1
(images).shape: torch.Size([1, 64, 255])
(labels): tensor([0])
current label: tensor(0)

All activity accuracy: 100.00 (last), 100.00 (average), 100.00 (best)
Proportion weighting accuracy: 100.00 (last), 100.00 (average), 100.00 (best)

Training complete.

saving model to: params/alco_snn/diehl_and_cook_2015/0_5_5_5_0.1_0.1_255_1.0_0.1_30_10_0.pt
Training complete for index: 0



KeyboardInterrupt: 

In [69]:
#FINAL MODEL (EVALUATE AFTER EVERY EPOCH)

n_train_full=dataset.get_data()[0]['X_train'].shape[0]
n_test_full=dataset.get_data()[0]['X_test'].shape[0]
size_of_train_set = n_train_full
train_set_update_interval = 1
resultsFilename = str(data)+"_results_data_ALL.txt"
if os.path.isfile(resultsFilename):
    print ("Results file exists")
else:
    print ("Results file doesn't exist, creating new file...")
starting_point=0
for i, params in enumerate(grid):
    if(i>=starting_point):
        overallPrecisionListAll = list()
        overallRecallListAll = list()
        overallAccuracyListAll = list()
        overallF1ListAll = list()
        overallTPListAll = list()
        overallTNListAll = list()
        overallFPListAll = list()
        overallFNListAll = list()
        overallConvergenceEpochAll = list()

        overallPrecisionListProp = list()
        overallRecallListProp = list()
        overallAccuracyListProp = list()
        overallF1ListProp = list()
        overallTPListProp = list()
        overallTNListProp = list()
        overallFPListProp = list()
        overallFNListProp = list()
        overallConvergenceEpochProp = list()


        trainingTimeList = list()
        with open(resultsFilename, "a") as text_file:
            print(f"\nResults for : {str(params)}\n\n", file=text_file)
        confusionMatricesList = list()
        for fold in np.arange(1):
            currentFold = fold
            # Train the network.
            with open(resultsFilename, "a") as text_file:
                print(f"\nFold : {str(currentFold)}\n", file=text_file)
                print(f"Size of Total Training Set : {str(n_train_full)}\n", file=text_file)
                print(f"Eval every {str(size_of_train_set)} observations.\n", file=text_file)
            print("Begin training for fold " + str(currentFold) + "\n")
            start = t()

            foldEpochAccuracyAll = list()
            foldEpochPrecisionAll = list()
            foldEpochRecallAll = list()
            foldEpochF1scoreAll = list()
            foldEpochTPAll = list()
            foldEpochTNAll = list()
            foldEpochFPAll = list()
            foldEpochFNAll = list()

            foldEpochAccuracyProp = list()
            foldEpochPrecisionProp = list()
            foldEpochRecallProp = list()
            foldEpochF1scoreProp = list()
            foldEpochTPProp = list()
            foldEpochTNProp = list()
            foldEpochFPProp = list()
            foldEpochFNProp = list()

            with open(resultsFilename, "a") as text_file:
                print(f"Current Training Index : {str(index)}/{str(int(n_train_full/size_of_train_set))}\n", file=text_file)
            dataset = EEGAlcoDatasetBalanced(testingByDatapoint=False)
            n_train=dataset.get_data()[0]['X_train'].shape[0]
            print('n_train: '+str(n_train))
            n_test=dataset.get_data()[0]['X_test'].shape[0]

            # Create a dataloader to iterate and batch data
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=gpu)
            result, trainingTime = createSNN(train=True, singleSample=False, n_neurons=params['n_neurons'],
                                             lr_pre=params['lr_pre'],
                                             inh=params['inh'], lr_post=params['lr_post'], exc=params['exc'],
                                             update_interval=train_set_update_interval, current_fold=currentFold,
                                             current_epoch=index)
            print("Training complete for index: " + str(index) + "\n")
            testResult, testingTime = createSNN(train=False, singleSample=False, n_neurons=params['n_neurons'],
                                                lr_pre=params['lr_pre'],
                                                inh=params['inh'], lr_post=params['lr_post'], exc=params['exc'],
                                                update_interval=params['update_interval'], current_fold=currentFold,
                                                current_epoch=index)

            TPAll = testResult['all'][0][0]
            FPAll = testResult['all'][0][1]
            FNAll = testResult['all'][1][0]
            TNAll = testResult['all'][1][1]
            TPProp = testResult['prop'][0][0]
            FPProp = testResult['prop'][0][1]
            FNProp = testResult['prop'][1][0]
            TNProp = testResult['prop'][1][1]

            precisionAll = (TPAll / (TPAll + FPAll))
            recallAll = (TPAll / (TPAll + FNAll))
            accuracyAll = (TPAll + TNAll) / (TPAll + FPAll + FNAll + TNAll)
            f1scoreAll = 2 * ((precisionAll * recallAll) / (precisionAll + recallAll))
            precisionProp = (TPProp / (TPProp + FPProp))
            recallProp = (TPProp / (TPProp + FNProp))
            accuracyProp = (TPProp + TNProp) / (TPProp + FPProp + FNProp + TNProp)
            f1scoreProp = 2 * ((precisionProp * recallProp) / (precisionProp + recallProp))

            foldEpochAccuracyAll.append(accuracyAll)
            foldEpochPrecisionAll.append(precisionAll)
            foldEpochRecallAll.append(recallAll)
            foldEpochF1scoreAll.append(f1scoreAll)
            foldEpochTPAll.append(TPAll)
            foldEpochTNAll.append(TNAll)
            foldEpochFPAll.append(FPAll)
            foldEpochFNAll.append(FNAll)

            foldEpochAccuracyProp.append(accuracyProp)
            foldEpochPrecisionProp.append(precisionProp)
            foldEpochRecallProp.append(recallProp)
            foldEpochF1scoreProp.append(f1scoreProp)
            foldEpochTPProp.append(TPProp)
            foldEpochTNProp.append(TNProp)
            foldEpochFPProp.append(FPProp)
            foldEpochFNProp.append(FNProp)

            convergenceObservationAccuracyAll=str(foldEpochAccuracyAll.index(max(foldEpochAccuracyAll)))
            convergenceObservationAccuracyProp=str(foldEpochAccuracyProp.index(max(foldEpochAccuracyProp)))

            with open(resultsFilename, "a") as text_file:
                print(f"Training complete for fold {str(currentFold)}\n", file=text_file)
                print(f"Number of Epochs: {str(len(foldEpochAccuracyAll))}\n", file=text_file)
                print(f"Fold {str(currentFold)} Test Metrics:\n", file=text_file)
                print(f"All Spikes:\n", file=text_file)
                print(f"TP: {str(foldEpochTPAll)}\nTN: {str(foldEpochTNAll)}\nFP: {str(foldEpochFPAll)}\nFN: {str(foldEpochFNAll)}\n", file=text_file)
                print(f"Accuracy: {str(foldEpochAccuracyAll)}\nPrecision: {str(foldEpochPrecisionAll)}\nRecall: {str(foldEpochRecallAll)}\nF1: {str(foldEpochF1scoreAll)}\n", file=text_file)
                print(f"Accuracy Maxes at Epoch: {str(convergenceObservationAccuracyAll)}\n", file=text_file)
                print(f"Prop Spikes:\n", file=text_file)
                print(f"TP: {str(foldEpochTPProp)}\nTN: {str(foldEpochTNProp)}\nFP: {str(foldEpochFPProp)}\nFN: {str(foldEpochFNProp)}\n",
                      file=text_file)
                print(f"Accuracy: {str(foldEpochAccuracyProp)}\nPrecision: {str(foldEpochPrecisionProp)}\nRecall: {str(foldEpochRecallProp)}\nF1: {str(foldEpochF1scoreProp)}\n", file=text_file)
                print(f"Accuracy Maxes at Epoch: {convergenceObservationAccuracyProp}\n", file=text_file)

            overallPrecisionListAll.append(foldEpochPrecisionAll)
            overallRecallListAll.append(foldEpochRecallAll)
            overallAccuracyListAll.append(foldEpochAccuracyAll)
            overallF1ListAll.append(foldEpochF1scoreAll)
            overallTPListAll.append(foldEpochTPAll)
            overallTNListAll.append(foldEpochTNAll)
            overallFPListAll.append(foldEpochFPAll)
            overallFNListAll.append(foldEpochFNAll)
            overallConvergenceEpochAll.append(convergenceObservationAccuracyAll)

            overallPrecisionListProp.append(foldEpochPrecisionProp)
            overallRecallListProp.append(foldEpochRecallProp)
            overallAccuracyListProp.append(foldEpochAccuracyProp)
            overallF1ListProp.append(foldEpochF1scoreProp)
            overallTPListProp.append(foldEpochTPProp)
            overallTNListProp.append(foldEpochTNProp)
            overallFPListProp.append(foldEpochFPProp)
            overallFNListProp.append(foldEpochFNProp)
            overallConvergenceEpochAll.append(convergenceObservationAccuracyProp)

        precisionAllMeanOverFolds=np.average(np.array(overallPrecisionListAll), axis=0)
        recallAllMeanOverFolds=np.average(np.array(overallRecallListAll), axis=0)
        accuracyAllMeanOverFolds=np.average(np.array(overallAccuracyListAll), axis=0)
        F1AllMeanOverFolds=np.average(np.array(overallF1ListAll), axis=0)
        TPAllMeanOverFolds = np.average(np.array(overallTPListAll), axis=0)
        TNAllMeanOverFolds = np.average(np.array(overallTNListAll), axis=0)
        FPAllMeanOverFolds = np.average(np.array(overallFPListAll), axis=0)
        FNAllMeanOverFolds = np.average(np.array(overallFNListAll), axis=0)
        convergenceAllMeanOverFolds = np.mean(np.array(overallConvergenceEpochAll).astype(np.float))+1

        precisionPropMeanOverFolds = np.average(np.array(overallPrecisionListProp), axis=0)
        recallPropMeanOverFolds = np.average(np.array(overallRecallListProp), axis=0)
        accuracyPropMeanOverFolds = np.average(np.array(overallAccuracyListProp), axis=0)
        F1PropMeanOverFolds = np.average(np.array(overallF1ListProp), axis=0)
        TPPropMeanOverFolds = np.average(np.array(overallTPListProp), axis=0)
        TNPropMeanOverFolds = np.average(np.array(overallTNListProp), axis=0)
        FPPropMeanOverFolds = np.average(np.array(overallFPListProp), axis=0)
        FNPropMeanOverFolds = np.average(np.array(overallFNListProp), axis=0)
        convergencePropMeanOverFolds = np.mean(np.array(overallConvergenceEpochProp).astype(np.float))+1

        with open(resultsFilename, "a") as text_file:
            print(f"Training complete for all folds for params: {str(params)}", file=text_file)
            print(f"Total Training Time: {str(trainingTime)}", file=text_file)
            
            print(f"Mean Test Metrics Over All Folds:\n", file=text_file)
            print(f"All Spikes:\n", file=text_file)
            print(f"Final Accuracy: {str(accuracyAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Precision: {str(precisionAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Recall: {str(recallAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final F1: {str(F1AllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TP: {str(TPAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FP: {str(FPAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TN: {str(TNAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FN: {str(FNAllMeanOverFolds[-1])}\n", file=text_file)
            print(f"Obs. Until Max Accuracy: {str(convergenceAllMeanOverFolds)}\n", file=text_file)
            print(
                f"TP: {str(list(TPAllMeanOverFolds))}\nTN: {str(list(TNAllMeanOverFolds))}\nFP: {str(list(FPAllMeanOverFolds))}\nFN: {str(list(FNAllMeanOverFolds))}\n",
                file=text_file)
            print(
                f"Accuracy: {str(list(accuracyAllMeanOverFolds))}\nPrecision: {str(list(precisionAllMeanOverFolds))}\nRecall: {str(list(recallAllMeanOverFolds))}\nF1: {str(list(F1AllMeanOverFolds))}\n",
                file=text_file)

            print(f"Prop Spikes:\n", file=text_file)
            print(f"Final Accuracy: {str(accuracyPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Precision: {str(precisionPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final Recall: {str(recallPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final F1: {str(F1PropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TP: {str(TPPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FP: {str(FPPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final TN: {str(TNPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Final FN: {str(FNPropMeanOverFolds[-1])}\n", file=text_file)
            print(f"Obs. Until Max Accuracy: {str(convergencePropMeanOverFolds)}\n", file=text_file)
            print(
                f"TP: {str(list(TPPropMeanOverFolds))}\nTN: {str(list(TNPropMeanOverFolds))}\nFP: {str(list(FPPropMeanOverFolds))}\nFN: {str(list(FNPropMeanOverFolds))}\n",
                file=text_file)
            print(
                f"Accuracy: {str(list(accuracyPropMeanOverFolds))}\nPrecision: {str(list(precisionPropMeanOverFolds))}\nRecall: {str(list(recallPropMeanOverFolds))}\nF1: {str(list(F1PropMeanOverFolds))}\n",
                file=text_file)

Results file exists
Begin training for fold 0

n_train: 600
params: [0, 5, 5, 5, 0.1, 0.1, 255, 1.0, 0.1, 30, 10, 0]
starting training (new network): 
NewNetwork: NewNetwork(
  (X): Input()
  (Ae): DiehlAndCookNodes()
  (Ai): LIFNodes()
  (X_to_Ae): Connection(
    (source): Input()
    (target): DiehlAndCookNodes()
  )
  (Ae_to_Ai): Connection(
    (source): DiehlAndCookNodes()
    (target): LIFNodes()
  )
  (Ai_to_Ae): Connection(
    (source): LIFNodes()
    (target): DiehlAndCookNodes()
  )
)
epoch: 0

Begin training.

n_examples: 600
n_examples: 600
i: 0
update_interval: 1
Train progress: (0 / 600)
len(images): 600
(images).shape: torch.Size([600, 64, 255])
(labels): tensor([0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
        0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1,
        0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1,
        1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,