## Set up TensorFlow

Import TensorFlow into your program to get started:

In [53]:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.8.0


In [54]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

## Build a machine learning model

Build a `tf.keras.Sequential` model by stacking layers.

In [55]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

In [56]:
predictions = model(x_train[:1]).numpy()
predictions

array([[-0.3521779 ,  0.35692912,  0.01328492,  0.24834192,  0.15572695,
        -0.64865565, -0.24591938,  0.68669105,  0.70681536, -0.21911447]],
      dtype=float32)

In [57]:
tf.nn.softmax(predictions).numpy()

array([[0.06000391, 0.1219385 , 0.08647649, 0.1093911 , 0.09971486,
        0.04460884, 0.06673092, 0.16957219, 0.17301928, 0.06854381]],
      dtype=float32)

In [58]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [59]:
loss_fn(y_train[:1], predictions).numpy()

3.1098232

In [74]:
opti = 'adam'

## Creating Custom Optimizer

skip these cells to try our code with basic adam optimizer

In [60]:
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.framework import ops
from tensorflow.python.training import optimizer
import tensorflow as tf
import numpy as np

In [64]:
# https://github.com/cilatpku/firework-algorithm/blob/master/fwa/BBFWA.py
class Firework(optimizer.Optimizer):
  def __init__(self,
                # params for prob
                evaluator = None,
                dim = 2,
                upper_bound = 100,
                lower_bound = -100,
                max_iter = 10000,
                max_eval = 20000,
                # params for method
                sp_size = 200,
                init_amp = 200, 
                name="Firework", use_locking=False, **kwargs):
    super(Firework, self).__init__(use_locking, name)

    ## Parameters

    # params of method
    self.sp_size = sp_size       # total spark size
    self.init_amp = init_amp     # initial dynamic amplitude

    # load params
    self.evaluator = evaluator
    self.dim = dim
    self.upper_bound = upper_bound
    self.lower_bound = lower_bound

    self.max_iter = max_iter
    self.max_eval = max_eval


    ## States

    # private init states
    self._num_iter = 0
    self._num_eval = 0
    self._dyn_amp = init_amp

    # public states
    self.best_idv = None    # best individual found
    self.best_fit = None    # best fitness found
    self.trace = []         # trace of best individual in each generation

    ## Fireworks
    self.fireworks = np.random.uniform(self.lower_bound, self.upper_bound, [1, self.dim])
    self.fireworks = self.fireworks.tolist()
    self.fits = self.evaluator(self.fireworks)

    ## Tensor versions of the constructor arguments, created in _prepare().
    self.dim_t = None
    self.upper_bound_t = None
    self.lower_bound_t = None
    self.max_iter_t = None
    self.max_eval_t = None
    self.sp_size_t = None
    self.init_amp_t = None

    self.fireworks_t = None
    self.fits_t = None

  def _create_slots(self, var_list):
    """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots"."""
    # Create slots for the first and second moments.
    for v in var_list:
        self._zeros_slot(v, "fireworks", self._name)
    for v in var_list:
        self._zeros_slot(v, "fits", self._name)

  def _prepare(self):
    # self.evaluator_t = ops.convert_to_tensor(self.evaluator, name="evaloator")
    self.dim_t = ops.convert_to_tensor(self.dim, name="dimention")
    self.upper_bound_t = ops.convert_to_tensor(self.upper_bound, name="upper_bound")
    self.lower_bound_t = ops.convert_to_tensor(self.lower_bound, name="lower_bound")
    self.max_iter_t = ops.convert_to_tensor(self.max_iter, name="max_iterations")
    self.max_eval_t = ops.convert_to_tensor(self.max_eval, name="max_eval")
    self.sp_size_t = ops.convert_to_tensor(self.sp_size, name="sp_size")
    self.init_amp_t = ops.convert_to_tensor(self.init_amp, name="init_amp")

    self.fireworks_t = ops.convert_to_tensor(self.fireworks, name="fireworks")
    self.fits_t = ops.convert_to_tensor(self.fits, name="fits")

    print(self.fireworks_t)

  def _resource_apply_dense(self, grad, var):
    evaluator = self.evaluator
    dim_t = math_ops.cast(self.dim_t, var.dtype.base_dtype)
    upper_bound_t = math_ops.cast(self.upper_bound_t, var.dtype.base_dtype)
    lower_bound_t = math_ops.cast(self.lower_bound_t, var.dtype.base_dtype)
    max_iter_t = math_ops.cast(self.max_iter_t, var.dtype.base_dtype)
    max_eval_t = math_ops.cast(self.max_eval_t, var.dtype.base_dtype)
    sp_size_t = math_ops.cast(self.sp_size_t, var.dtype.base_dtype)
    init_amp_t = math_ops.cast(self.init_amp_t, var.dtype.base_dtype)

    fits = self.get_slot(grad, "fits")
    fireworks = self.get_slot(var, "fireworks")

    fireworks_update, fits_update = self.iter(self.fireworks, self.fits)
    
    self.fireworks = fireworks_update
    self.fits = fits_update

    fireworks_update_t = math_ops.cast(fireworks_update, var.dtype.base_dtype)
    fits_update_t = math_ops.cast(fits_update, var.dtype.base_dtype)

    self.fireworks_t = fireworks_update_t
    self.fits_t = fits_update_t
    
    print("fireworks_update : ", fireworks_update)
    print("fits_update : ", fits_update)

    #Create an op that groups multiple operations
    #When this op finishes, all ops in input have finished
    return control_flow_ops.group(*[fireworks_update_t, fits_update_t])

  ## Helper functions
  def iter(self, fireworks, fits):

      print("...\n")
      
      e_sparks, e_fits = self._explode(fireworks, fits)
        
      n_fireworks, n_fits = self._select(fireworks, fits, e_sparks, e_fits)    

      # update states
      if n_fits[0] < fits[0]:
          self._dyn_amp *= 1.2
      else:
          self._dyn_amp *= 0.9

      self._num_iter += 1
      self._num_eval += len(e_sparks)
          
      self.best_idv = n_fireworks[0]
      self.best_fit = n_fits[0]
      self.trace.append([n_fireworks[0], n_fits[0], self._dyn_amp])

      fireworks = n_fireworks
      fits = n_fits
      
      return fireworks, fits

  def _explode(self, fireworks, fits):
      
      bias = np.random.uniform(-self._dyn_amp, self._dyn_amp, [self.sp_size, self.dim])
      rand_samples = np.random.uniform(self.lower_bound, self.upper_bound, [self.sp_size, self.dim])
      e_sparks = fireworks + bias
      in_bound = (e_sparks > self.lower_bound) * (e_sparks < self.upper_bound)
      e_sparks = in_bound * e_sparks + (1 - in_bound) * rand_samples
      e_sparks = e_sparks.tolist()
      e_fits = self.evaluator(e_sparks)
      return e_sparks, e_fits    

  def _select(self, fireworks, fits, e_sparks, e_fits):
      idvs = fireworks + e_sparks
      fits = fits + e_fits
      idx = np.argmin(fits)
      return [idvs[idx]], [fits[idx]]


##################################################
##################################################

  def get_config(self):
      base_config = super().get_config()
      return {
          **base_config,
          "learning_rate": self._serialize_hyperparameter("learning_rate"),
          "decay": self._serialize_hyperparameter("decay"),
          "momentum": self._serialize_hyperparameter("momentum"),
      }

  def _apply_dense(self, grad, var):
      raise NotImplementedError("Dense gradient updates are not supported.")

  def _apply_sparse(self, grad, var):
      raise NotImplementedError("Sparse gradient updates are not supported.")

  def _resource_apply_sparse(self, grad, var):
      raise NotImplementedError("Sparse Resource gradient updates are not supported.")

In [77]:
# New Custom Function
# returns sum of squares of all integers in list
obj_func = lambda x: [sum([_ * _ for _ in xi]) for xi in x]
opti = Firework(evaluator=obj_func, dim=100, max_eval=10000)

## Compiling your model

In [78]:
model.compile(optimizer=opti,
              loss=loss_fn,
              metrics=['accuracy'])

## Train and evaluate your model

Use the `Model.fit` method to adjust your model parameters and minimize the loss: 

In [79]:
model.fit(x_train, y_train, epochs=15)

Epoch 1/15
Tensor("Firework/fireworks:0", shape=(1, 100), dtype=float32)
...

fireworks_update :  [[-94.36446767743391, -24.86005182145515, -14.634340523664136, -53.25310244797101, -26.785950828821257, -10.916962541477787, 76.1894402850952, 19.985818566057404, 79.05189038036458, -23.521711593113224, 28.872795001920295, 34.44101262411448, -82.72484279147935, -73.4921152610223, -61.62774412029073, 14.442095671969213, 7.898506950072587, 42.22452512970352, -8.957014933911069, 70.19348137082858, 89.22511161074661, -2.465497476063888, 86.70635533217902, -44.9138732598784, 82.25851285633007, 30.4697838062103, 33.96291108142168, 14.175950761490881, 73.38645458577801, -14.038183171278604, 19.175681293334648, -60.035222401943216, -10.52458886864369, 9.603111492761485, 72.87500224340826, -57.31690433700291, 0.6579867852005208, 21.127940135966966, -9.459401922300586, -34.370261840644616, 0.32509974564614197, 27.17527450230783, 54.53559667635173, -22.259498097834694, -99.74385473880996, 17.67812275

<keras.callbacks.History at 0x7f957caee410>

In [80]:
model.evaluate(x_test,  y_test, verbose=2)

313/313 - 1s - loss: 0.0756 - accuracy: 0.9813 - 537ms/epoch - 2ms/step


[0.07561153918504715, 0.9812999963760376]

In [81]:
probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

In [82]:
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[2.7883562e-10, 1.1842704e-14, 1.7307674e-11, 1.8010880e-06,
        4.4793260e-13, 1.6837262e-11, 9.6992611e-17, 9.9999797e-01,
        4.0200987e-10, 1.8609022e-07],
       [1.9432909e-15, 3.5976461e-10, 1.0000000e+00, 4.0839342e-11,
        6.4609330e-27, 1.2592099e-14, 6.8400125e-13, 4.1005563e-20,
        9.5747799e-11, 7.3193683e-23],
       [4.7800409e-12, 9.9994862e-01, 2.0295489e-07, 1.3834687e-09,
        5.5934738e-06, 6.9823325e-10, 4.6222792e-07, 4.3371187e-05,
        1.6790074e-06, 9.4725816e-12],
       [9.9999046e-01, 3.9611821e-16, 6.8503419e-07, 1.9962760e-09,
        1.4994784e-07, 1.5143118e-07, 7.6610486e-06, 6.1766130e-08,
        3.9004142e-10, 8.8872719e-07],
       [1.9450338e-08, 1.9749728e-13, 1.8847924e-06, 1.1021268e-09,
        9.9921465e-01, 1.5536166e-09, 1.8986385e-06, 1.0698554e-05,
        1.3608931e-07, 7.7063282e-04]], dtype=float32)>