# In [21]:
from neuralflow import energy_model

# In [23]:
class EnergyModel_DB(energy_model.EnergyModel):
    """EnergyModel subclass whose fit entry point accepts initial parameters.

    Adds ``fit_DB``, which validates the optimization options and then runs
    ``self._SGD_optimization`` with an extra ``initial_params`` argument.
    """

    def fit_DB(self, data, optimizer, optimization, save=None,
               initial_params=None):
        """Optimize a model from data.

        Parameters
        ----------
        data : dictionary
            Training and validation data: {'dataTR': data, 'dataCV': data or
            None}.
            data is a numpy array with dimensions (num_trial, 2),
            dtype=np.ndarray. It contains spike data packed as a numpy array
            of size (num_trial, 2), where each entry is a 1D array. num_trial
            is the number of trials. For each trial, the first column contains
            inter-spike intervals (ISIs) in seconds, and the second column
            contains the corresponding neuronal IDs (neurons are enumerated by
            integer IDs starting from 0, and trial termination is indicated by
            ID -1).
            data[i][0] - 1D array, ISIs of type self.dtype for trial i. The
            first entry is the time interval between trial start and the time
            of the first spike. The last entry can be the time interval
            between the last spike and the trial termination time.
            data[i][1] - 1D array, neuronal IDs of type int64 for trial i.
            The last entry is -1 if the trial termination time is recorded.
            Example: neuron 0 spiked at times 0.12, 0.15, 0.25, and neuron 1
            spiked at times 0.05, 0.2. Trial 0 started at t=0 and ended at
            t=0.28. In this case data[0][0]=np.array([0.05,0.07,0.03,0.05,0.05,
            0.03]), and data[0][1]=np.array([1,0,0,1,0,-1]).
        optimizer : ENUM('ADAM','GD','coordinate_descent')
            Optimization strategy.
        optimization : dictionary
            Options for optimization:
                gamma : dictionary
                    Specifies learning rates and parameters to optimize.
                    For ADAM, contains the following:
                        params_to_opt : list
                            List of parameters to optimize, can include
                            F,F0,Fr,C,fr,D.
                        alpha : float
                            ADAM learning rate.
                        beta1 : float
                            ADAM weight for momentum.
                        beta2 : float
                            ADAM weight for RMS prop momentum.
                        epsilon : float
                            Denominator added to RMS prop for numerical
                            stability.
                    For GD or coordinate_descent, contains the parameters to
                    be optimized as keys and the corresponding learning rates
                    as values (must contain at least one entry):
                        F : float
                            Learning rate for the driving force.
                        F0 : float
                            Learning rate for F0, which determines p0.
                        Fr : float
                            Learning rate for firing rates (optimized
                            through Fr).
                        C : float
                            Learning rate for C, which is the coefficient
                            that links Fr and fr via fr=C*Fr.
                        D : float
                            Learning rate for the diffusion coefficient.
                max_iteration : int, optional
                    Maximum number of epochs; optimization terminates after
                    reaching this limit. The default is 100.
                loglik_tol : float, optional
                    Early stopping criterion for relative improvement of
                    loglik. When the relative change of loglik is less than
                    this value, optimization stops. The default is 0, in
                    which case optimization is never terminated by this
                    condition.
                etaf : float, optional
                    Regularization strength (set to zero for unregularized
                    optimization). The default is zero.
                sgd_options : dictionary, optional
                    Parameters of stochastic gradient descent:
                        mini_batch_number : int
                            Number of minibatches. Set to 1 for GD, set to
                            the number of trials for full SGD.
                        nEVsolutions : int
                            Number of EV problem solutions per epoch. Set
                            equal to mini_batch_number to solve the EV
                            problem on every iteration.
                Fr_opt : dictionary, optional
                    A dictionary that controls the line search of the
                    constant C. Available options:
                        max_fun_eval : int
                            Maximum number of iterations for the search.
                        schedule : numpy array, dtype=int
                            Epoch numbers on which the line search for C
                            will be performed.
                        nSearchPerEpoch : int
                            Number of line searches per epoch. Should not
                            exceed the number of iterations per epoch.
                D_opt : dictionary, optional
                    A dictionary that controls the line search of the
                    diffusion coefficient D. Entries are the same as in
                    Fr_opt.
        save : dictionary, optional
            Options for saving the results. The default is None, which is
            treated as an empty dictionary (a fresh one on every call).
                path : str
                    Path for saving. The default is None.
                stride : int
                    Save files every `stride` iterations. The default is
                    max_iteration (only save the final file).
                schedule : numpy array, dtype=int
                    1D array with iteration numbers on which the fitted
                    parameters (e.g. peq and/or D) will be saved.
                    The default is np.arange(1-new_sim, max_iteration+1),
                    which includes all the iteration numbers, and saves
                    fitting results on each iteration.
                sim_start : int
                    Starting iteration number (0 if new simulation,
                    otherwise int >0). The default is 0.
        initial_params : optional
            Forwarded unchanged to ``self._SGD_optimization`` as its fifth
            positional argument. The default is None.
            NOTE(review): its expected structure (presumably starting values
            for the fitted parameters) is defined by ``_SGD_optimization``,
            which is not visible here. That method must accept this extra
            argument. TODO confirm.

        Returns
        -------
        em : self
            A fitted EnergyModel object.
        """
        # A literal `{}` default would be one dict shared by every call. If
        # the option checks / optimizer fill in defaults in place, later fits
        # would silently inherit them. Build a fresh dict per call instead.
        if save is None:
            save = {}

        # Check inputs
        self._check_optimization_options(data, optimizer, optimization, save)

        # Maximum Likelihood Estimation of the parameters
        if self.verbose:
            print("Performing Estimation of the model parameters...")
            print("The chosen optimizer is: " + str(optimizer))
        self.iterations_SGD_ = self._SGD_optimization(
            optimizer, data, optimization, save, initial_params
        )
        # Return the fitted model, as documented, so calls can be chained.
        return self