In [None]:
# Import the tensorflow
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.9.2


In [None]:
# Build the model
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

In [None]:
# Instantiate a loss function.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
# import numpy as np
batch_size = 64
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# x_train = np.reshape(x_train, (-1, 784))
# x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# batched_train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

# Define metric for the validation
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
import numpy as np
import copy as cp

def get_model_weights(model):
  w = []
  for l in model.layers:
    w.append(cp.deepcopy(l.get_weights()))
  return w
  
def set_model_weights(model, weights):
  for i, l in enumerate(model.layers):
    l.set_weights(weights[i])

def compute_sgd(x_train, y_train, loss_fn, w):
  with tf.GradientTape() as tape:
      logits = model(x_train, training=True)
      loss_value = loss_fn(y_train, logits)
  grads = tape.gradient(loss_value, w)
  return grads, loss_value

def spider_boost_training(model, loss_fn, x_train, y_train, q, batch_size = 64, epochs=50):
  
  lipshitz_const = 200  # np.linalg.norm(tx, 'fro') ** 2
  learning_rate = 1/(2*lipshitz_const)
  optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

  grads_h = []
  
  # w = get_model_weights(model)
  # w_h = [w] 
  w_h = [model.trainable_weights]

  n = len(y_train)

  for epoch in range(epochs):
      if epoch%10 == 0:
        print("\nStart of epoch %d" % (epoch,))
      # Compute full GD each q'th epoch
      if epoch % q == 0:
        grads, loss_value = compute_sgd(x_train, y_train, loss_fn, model.trainable_weights)
      # Otherwise make calculate SGD
      else:
        print(w_h[epoch] is w_h[epoch-1])
        idxs = np.random.choice(n, batch_size, replace=True)
        x_train_batch = x_train[idxs]
        y_train_batch = y_train[idxs]

        # grads, loss_value = compute_sgd(x_train_batch, y_train_batch, loss_fn, model.trainable_weights)
        # set_model_weights(model, w_h[epoch-1])
        # prev_grads, _ = compute_sgd(x_train_batch, y_train_batch, loss_fn, model.trainable_weights)
        grads, loss_value = compute_sgd(x_train_batch, y_train_batch, loss_fn, model.trainable_weights)
        prev_grads, _ = compute_sgd(x_train_batch, y_train_batch, loss_fn, w_h[epoch-1])

        grads = [(grad - prev_grad + grad_h)/len(idxs) for grad, prev_grad, grad_h in zip(grads, prev_grads, grads_h[epoch-1])]

      # set_model_weights(model, w_h[epoch])

      optimizer.apply_gradients(zip(grads, model.trainable_weights))
      grads_h.append(grads)

      # w = get_model_weights(model)
      # w_h.append(w)
      w_h.append(model.trainable_weights)
      if epoch%10 == 0:
        print(
            "Training loss at epoch %d: %.4f"
            % (epoch, float(loss_value))
        )

In [None]:
spider_boost_training(model, loss_fn, x_train, y_train, 10)


Start of epoch 0
Training loss at epoch 0: 186.7522

Start of epoch 1
False
Training loss at epoch 1: 55.0638

Start of epoch 2
False
Training loss at epoch 2: 48.4164

Start of epoch 3
False
Training loss at epoch 3: 77.9045

Start of epoch 4
False
Training loss at epoch 4: 58.0486

Start of epoch 5
False
Training loss at epoch 5: 70.5925

Start of epoch 6
False
Training loss at epoch 6: 72.3616

Start of epoch 7
False
Training loss at epoch 7: 54.6220

Start of epoch 8
False
Training loss at epoch 8: 54.3897

Start of epoch 9
False
Training loss at epoch 9: 67.2313

Start of epoch 10
Training loss at epoch 10: 63.4636

Start of epoch 11
False
Training loss at epoch 11: 40.2330

Start of epoch 12
False
Training loss at epoch 12: 41.4355

Start of epoch 13
False
Training loss at epoch 13: 46.4037

Start of epoch 14
False
Training loss at epoch 14: 30.1770

Start of epoch 15
False
Training loss at epoch 15: 33.1872

Start of epoch 16
False
Training loss at epoch 16: 35.1767

Start of e

In [None]:
def validate_model(val_dataset, model, metric):
  # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        metric.update_state(y_batch_val, val_logits)
    val_acc = metric.result()
    metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))

In [None]:
validate_model(val_dataset, model, val_acc_metric)

Validation acc: 0.2678
