In [1]:
"""Class that represents the network to be evolved."""
import random
import logging
from train import train_and_score
from inference import predict

INFO:tensorflow:Enabling eager execution
INFO:tensorflow:Enabling v2 tensorshape
INFO:tensorflow:Enabling resource variables
INFO:tensorflow:Enabling tensor equality
INFO:tensorflow:Enabling control flow v2



In [2]:
class Network:
    """Represent a network and let us operate on it.

    A ``Network`` holds one concrete choice of hyperparameters
    (``self.network``), an optional model object attached with
    ``set_model`` (``self.model``), and the score produced by training it
    (``self.accuracy``).

    Currently only works for an MLP.
    """

    def __init__(self, nn_param_choices=None):
        """Initialize our network.

        Args:
            nn_param_choices (dict): Mapping from parameter name to the
                sequence of candidate values that ``create_random`` picks
                from, for example:
                    'window_size': [i for i in range(1, 50)],
                    'nb_neurons': [i for i in range(3, 41, 1)],
                    'nb_layers': [i for i in range(1, 11)],
                    'batch_size': [i for i in range(1, 21)],
                    'epoch': [i for i in range(10, 501)],
                    'optimizer': ['rmsprop', 'adam', 'sgd', 'adagrad',
                                  'adadelta', 'adamax', 'nadam', 'ftrl'],
                Must be provided before calling ``create_random``.
        """
        self.accuracy = 0.
        self.nn_param_choices = nn_param_choices
        self.network = {}  # (dict): the chosen network parameters
        # Attached via set_model() and consumed by inference(); initialised
        # here so the attribute always exists.
        self.model = None

    def create_random(self):
        """Create a random network.

        For every key in ``nn_param_choices``, pick one candidate value with
        ``random.choice`` and store it under the same key in
        ``self.network``. Existing entries for those keys are overwritten.
        """
        for key, choices in self.nn_param_choices.items():
            self.network[key] = random.choice(choices)

    def create_set(self, network):
        """Set network properties.

        Args:
            network (dict): The network parameters. The dict is stored as-is
                (not copied), so later mutations by the caller are visible
                here.
        """
        self.network = network

    def set_model(self, model):
        """Attach the model object later used by ``inference``."""
        self.model = model

    def train(self):
        """Train the network and record the accuracy.

        Calls ``train_and_score(self)`` and stores its result in
        ``self.accuracy``. Training is skipped when ``self.accuracy`` is
        already non-zero, so repeated calls do not retrain the network.
        A network whose score is exactly 0 is retrained on every call.
        """
        if self.accuracy == 0.:
            self.accuracy = train_and_score(self)

    def inference(self, x, y):
        """Return ``predict(x, y, self.model)`` for the attached model.

        Call ``set_model`` first; otherwise ``None`` is passed as the model.
        """
        return predict(x, y, self.model)

    def print_network(self):
        """Log the network parameters and its accuracy as a percentage."""
        logging.info(self.network)
        # Lazy %-style args: formatting only happens if INFO is enabled.
        logging.info("Network accuracy: %.2f%%", self.accuracy * 100)