##### Copyright 2020 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# M-layer experiments
This notebook trains M-layers on the problems discussed in "Intelligent Matrix Exponentiation".



When running this notebook locally, the `m_layer` Python module should come with the colab and should already be present.

The code of the `m_layer` Python module can be downloaded from the google-research GitHub repository.

In [0]:
import os.path
if os.path.isfile('m_layer.py'):
  from m_layer import MLayer
else:
  !if ! type "svn" > /dev/null; then sudo apt-get install subversion; fi
  !svn export https://github.com/google-research/google-research/trunk/m_layer
  from m_layer.m_layer import MLayer

In [0]:
GLOBAL_SEED = 1
import numpy as np
np.random.seed(GLOBAL_SEED)
import itertools
import functools
import operator
import logging
logging.getLogger('tensorflow').disabled = True

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from matplotlib import pylab 

print(tf.__version__)

# Generate a spiral and show extrapolation

In [0]:
# Hyperparameters of the spiral classification experiment.
SPIRAL_DIM_REP = 10  # Units of the Dense layer that feeds the M-layer.
SPIRAL_DIM_MATRIX = 10  # Passed to MLayer as `dim_m`.
SPIRAL_LR = 1e-3  # Initial RMSprop learning rate.
SPIRAL_EPOCHS = 200  # Upper bound on training epochs (early stopping applies).
SPIRAL_BATCH_SIZE = 16  # Mini-batch size used by `model.fit`.

def spiral_get_model():
  """Builds the uncompiled binary classifier used for the spiral task.

  Returns:
    A `tf.keras` Sequential model taking 2-d inputs: Dense -> MLayer ->
    activity regularization -> Flatten -> one sigmoid output unit.
  """
  embedding = tf.keras.layers.Dense(SPIRAL_DIM_REP, input_shape=(2,))
  m_layer = MLayer(dim_m=SPIRAL_DIM_MATRIX)
  regularizer = tf.keras.layers.ActivityRegularization(l2=1e-4)
  flatten = tf.keras.layers.Flatten()
  classifier = tf.keras.layers.Dense(1, activation='sigmoid')
  return tf.keras.models.Sequential(
      [embedding, m_layer, regularizer, flatten, classifier])


def spiral_generate(n_points, noise=0.5, rng=None):
  """Samples two point-symmetric noisy spiral arms.

  Args:
    n_points: Number of points drawn per arm.
    noise: Half-width of the uniform jitter added to each coordinate before
      the final 0.5 scaling.
    rng: Optional `np.random.RandomState`; a fresh unseeded one is created
      when omitted.

  Returns:
    A pair `(points, labels)`. `points` has shape `(2 * n_points, 2)`: the
    first arm followed by its negation. `labels` has shape `(2 * n_points,)`
    and holds 0.0 for the first arm and 1.0 for the second.
  """
  generator = np.random.RandomState() if rng is None else rng
  # The draws below must stay in this order (angle, x jitter, y jitter) so a
  # seeded generator yields the same data set.
  angle = np.sqrt(0.001 + generator.rand(n_points, 1))
  angle = angle * 1000 * (2 * np.pi) / 360.0
  jitter_x = (2 * generator.rand(n_points, 1) - 1) * noise
  x = 0.5 * (np.cos(angle) * angle + jitter_x)
  jitter_y = (2 * generator.rand(n_points, 1) - 1) * noise
  y = 0.5 * (np.sin(angle) * angle + jitter_y)

  first_arm = np.concatenate((x, y), axis=1)
  second_arm = np.concatenate((-x, -y), axis=1)
  points = np.concatenate((first_arm, second_arm), axis=0)
  labels = np.concatenate((np.zeros(n_points), np.ones(n_points)))
  return (points, labels)


def spiral_run(fig=None):
  """Trains the spiral classifier and plots its decision regions.

  Args:
    fig: Optional matplotlib figure to draw into; a new one is created when
      omitted.
  """
  if fig is None:
    fig = pylab.figure()
  # Seed the generator explicitly: `np.random.seed` does not affect the fresh
  # `RandomState` that `spiral_generate` would otherwise create, so the data
  # set would differ between runs.
  x_train, y_train = spiral_generate(
      1000, rng=np.random.RandomState(GLOBAL_SEED))

  model = spiral_get_model()
  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  opt = tf.keras.optimizers.RMSprop(learning_rate=SPIRAL_LR)
  model.compile(loss='binary_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='loss', factor=0.2, patience=5, min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=30,
                                 restore_best_weights=True)
  model.fit(x_train, y_train, epochs=SPIRAL_EPOCHS,
            batch_size=SPIRAL_BATCH_SIZE, verbose=2,
            callbacks=[reduce_lr, early_stopping])

  # Evaluate the model on a regular grid covering [-10, 10]^2.
  delta = 0.5 ** 3
  xs = np.arange(-10, 10.01, delta)
  ys = np.arange(-10, 10.01, delta)
  # With 'ij' indexing, entry [i, j] of each grid refers to (xs[i], ys[j]),
  # and flattening enumerates the points with x varying slowest.
  grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')
  grid_points = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)
  scores = np.arcsinh(model.predict(grid_points).reshape(grid_x.shape))

  axes = fig.gca()
  # Plot orientation: the horizontal axis shows the second input feature and
  # the vertical axis the first one, for both the contours and the points.
  axes.contourf(grid_y, grid_x, scores,
                levels=[0.5, 1.0],
                colors=[(0.89, 0.51, 0.41, 0.2), (0.41, 0.67, 0.81, 0.2)])
  axes.contour(grid_y, grid_x, scores, levels=[0.5])
  axes.grid()
  axes.plot(x_train[y_train==0, 1], x_train[y_train==0, 0], '.',
            label='Class 1')
  axes.plot(x_train[y_train==1, 1], x_train[y_train==1, 0], '.',
            label='Class 2')
  # The labels above are only visible once a legend is drawn.
  axes.legend()
  fig.show()

# Train the spiral classifier and draw its decision regions.
spiral_run()

# Train an M-layer on multivariate polynomials such as the determinant

In [0]:
# Hyperparameters of the polynomial (permanent / determinant) experiment.
POLY_BATCH_SIZE = 32  # Mini-batch size used by `model.fit`.
POLY_DIM_MATRIX = 8  # Passed to MLayer as `dim_m`.
POLY_DIM_INPUT_MATRIX = 3  # Side length of the square input matrices.
POLY_EPOCHS = 150  # Upper bound on training epochs (early stopping applies).
POLY_SEED = 123  # Seed of the RandomState that draws the input matrices.
POLY_LOW = -1  # Lower bound of the uniform distribution of matrix entries.
POLY_HIGH = 1  # Upper bound of the uniform distribution of matrix entries.
# `poly_run` draws 5/4 of this many matrices and holds 20% of them out for
# validation, leaving this many for the gradient updates.
POLY_NUM_SAMPLES = 8192
POLY_LR = 1e-3  # Initial RMSprop learning rate.
POLY_DECAY = 1e-6  # `decay` argument of the RMSprop optimizer.

def poly_get_model():
  """Builds the uncompiled regression model for the polynomial tasks.

  Returns:
    A `tf.keras` Sequential model that flattens a square input matrix and
    applies MLayer -> activity regularization -> Flatten -> one linear unit.
  """
  matrix_shape = (POLY_DIM_INPUT_MATRIX, POLY_DIM_INPUT_MATRIX)
  layers = [
      tf.keras.layers.Flatten(input_shape=matrix_shape),
      MLayer(dim_m=POLY_DIM_MATRIX, matrix_init='normal'),
      tf.keras.layers.ActivityRegularization(l2=1e-4),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(1),
  ]
  return tf.keras.models.Sequential(layers)


def poly_fun(x, permanent=False):
  """Evaluates the target polynomial of a square matrix.

  Args:
    x: 2-d array indexed as `x[row, column]`.
    permanent: When true, return the permanent (the sum over all permutations
      of the product of the selected entries); otherwise the determinant.

  Returns:
    The permanent or determinant of `x` as a scalar.
  """
  if not permanent:
    return np.linalg.det(x)

  total = 0
  for perm in itertools.permutations(range(x.shape[0])):
    # Product of the entries picked by this permutation, one per row.
    term = 1
    for row, col in enumerate(perm):
      term = term * x[row, col]
    total = total + term
  return total


def poly_run(permanent=False):
  """Trains an M-layer model to regress a matrix polynomial and reports MSE.

  Random square matrices with entries uniform in [POLY_LOW, POLY_HIGH] are
  generated for training and for a separate test set; the fitted model's loss
  on both sets is printed.

  Args:
    permanent: When true the regression target is the matrix permanent,
      otherwise the determinant (see `poly_fun`).
  """
  rng = np.random.RandomState(seed=POLY_SEED)
  # 20% of the generated matrices are held out by `validation_split` below,
  # so POLY_NUM_SAMPLES of them remain for the gradient updates.
  num_train = POLY_NUM_SAMPLES * 5 // 4
  x_train = rng.uniform(size=(num_train, POLY_DIM_INPUT_MATRIX,
                              POLY_DIM_INPUT_MATRIX), low=POLY_LOW,
                         high=POLY_HIGH)
  x_test = rng.uniform(size=(100000, POLY_DIM_INPUT_MATRIX,
                             POLY_DIM_INPUT_MATRIX), low=POLY_LOW,
                       high=POLY_HIGH)
  y_train = np.array([poly_fun(x, permanent=permanent) for x in x_train])
  y_test = np.array([poly_fun(x, permanent=permanent) for x in x_test])
  model = poly_get_model()
  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  opt = tf.keras.optimizers.RMSprop(learning_rate=POLY_LR, decay=POLY_DECAY)

  model.compile(loss='mse', optimizer=opt)
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='val_loss', factor=0.2, patience=5, min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                    patience=30,
                                                    restore_best_weights=True)
  model.fit(x_train, y_train, batch_size=POLY_BATCH_SIZE,
            epochs=POLY_EPOCHS,
            validation_split=0.2,
            shuffle=True,
            verbose=2,
            callbacks=[reduce_lr, early_stopping])
  # Note: the training score includes the held-out validation portion.
  score_train = model.evaluate(x=x_train, y=y_train)
  score_test = model.evaluate(x=x_test, y=y_test)

  print('Train, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_train))
  print('Test, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_test))

Permanents

In [0]:
# Regress the permanent of random square matrices.
poly_run(permanent=True)

Determinants

In [0]:
# Regress the determinant of random square matrices.
poly_run(permanent=False)

# Train an M-layer on CIFAR-10


In [0]:
# Hyperparameters of the CIFAR-10 experiment.
CIFAR_DIM_REP = 35  # Units of the Dense layer that feeds the M-layer.
CIFAR_DIM_MAT = 30  # Passed to MLayer as `dim_m`.
CIFAR_LR = 1e-3  # Initial SGD learning rate.
CIFAR_DECAY = 1e-6  # `decay` argument of the SGD optimizer.
CIFAR_MOMENTUM = 0.9  # `momentum` argument of the SGD optimizer.
CIFAR_BATCH_SIZE = 32  # Mini-batch size used by `model.fit`.
CIFAR_EPOCHS = 150  # Upper bound on training epochs (early stopping applies).
CIFAR_NAME = 'cifar10'  # Dataset name handed to `tfds.load`.
CIFAR_NUM_CLASSES = 10  # Width of the one-hot labels and of the softmax layer.

def cifar_load_dataset():
  """Loads the dataset named by CIFAR_NAME as in-memory NumPy arrays.

  Returns:
    `((x_train, y_train), (x_test, y_test))` where the images are float32
    arrays divided by 255 and the labels are one-hot encoded with
    CIFAR_NUM_CLASSES columns.
  """
  # Load both splits in full (batch_size=-1) before converting them.
  loaded = [
      tfds.load(CIFAR_NAME, split=split_name, with_info=False, batch_size=-1)
      for split_name in ('train', 'test')
  ]
  train_arrays, test_arrays = [tfds.as_numpy(split) for split in loaded]

  train_images = train_arrays['image']
  train_labels = train_arrays['label']
  test_images = test_arrays['image']
  test_labels = test_arrays['label']
  print('x_train shape:', train_images.shape)
  print(train_images.shape[0], 'train samples')
  print(test_images.shape[0], 'test samples')

  train_onehot = tf.keras.utils.to_categorical(train_labels, CIFAR_NUM_CLASSES)
  test_onehot = tf.keras.utils.to_categorical(test_labels, CIFAR_NUM_CLASSES)

  train_pair = (train_images.astype('float32') / 255, train_onehot)
  test_pair = (test_images.astype('float32') / 255, test_onehot)
  return train_pair, test_pair

def cifar_get_model():
  """Builds the uncompiled CIFAR-10 classifier.

  Returns:
    A `tf.keras` Sequential model for 32x32x3 images: Flatten -> Dense ->
    MLayer (with bias) -> activity regularization -> Flatten -> softmax over
    CIFAR_NUM_CLASSES classes.
  """
  layers = [
      tf.keras.layers.Flatten(input_shape=(32, 32, 3)),
      tf.keras.layers.Dense(CIFAR_DIM_REP),
      MLayer(dim_m=CIFAR_DIM_MAT, with_bias=True),
      # The first positional argument of ActivityRegularization is `l1`.
      # NOTE(review): the other models in this notebook use `l2=1e-4`;
      # confirm that an L1 penalty is intended here.
      tf.keras.layers.ActivityRegularization(l1=1e-4),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(CIFAR_NUM_CLASSES, activation='softmax'),
  ]
  return tf.keras.models.Sequential(layers)

def cifar_run():
  """Trains the CIFAR-10 M-layer classifier and prints test loss/accuracy."""
  (x_train, y_train), (x_test, y_test) = cifar_load_dataset()
  model = cifar_get_model()
  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  opt = tf.keras.optimizers.SGD(learning_rate=CIFAR_LR,
                                momentum=CIFAR_MOMENTUM,
                                decay=CIFAR_DECAY)

  model.compile(loss='categorical_crossentropy', optimizer=opt,
                metrics=['accuracy'])

  # With metrics=['accuracy'], tf.keras 2.x logs the validation metric as
  # 'val_accuracy'. Monitoring the old 'val_acc' key leaves both callbacks
  # without data, so they never reduce the learning rate or stop training.
  # The mode is given explicitly because higher accuracy is better.
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='val_accuracy', mode='max', factor=0.2, patience=5, min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                                    mode='max',
                                                    patience=15,
                                                    restore_best_weights=True)

  model.fit(
      x_train,
      y_train,
      batch_size=CIFAR_BATCH_SIZE,
      epochs=CIFAR_EPOCHS,
      validation_split=0.1,
      shuffle=True,
      verbose=2,
      callbacks=[reduce_lr, early_stopping])

  scores = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', scores[0])
  print('Test accuracy:', scores[1])

# Train and evaluate the CIFAR-10 classifier.
cifar_run()

# Train an M-layer on periodic functions

In [0]:
# Hyperparameters of the periodic-function experiment.
PERIODIC_BATCH_SIZE = 128  # Passed to `periodic_run` as `batch_size`.
PERIODIC_DIM_MATRIX = 6  # First positional argument of MLayer.
PERIODIC_EPOCHS = 300  # Passed to `periodic_run` as `epochs`.
PERIODIC_SEED = 123  # Seed of the RandomState created in `periodic_run`.
PERIODIC_LR = 5e-3  # Passed to `periodic_run` as `lr`.
PERIODIC_DECAY = 5e-6  # Passed to `periodic_run` as `decay`.

def periodic_matrix_init(shape, rng=None, **kwargs):
  """Draws small Gaussian matrices and shifts their diagonals by -10.

  Args:
    shape: Shape of the returned array; axes 1 and 2 index the rows and
      columns of each matrix.
    rng: Optional `np.random.RandomState`; a fresh unseeded one is created
      when omitted.
    **kwargs: Accepted and ignored.

  Returns:
    A float32 array of the given shape with N(0, 0.01^2) entries and 10
    subtracted from the first `shape[1]` diagonal entries of every matrix.
  """
  generator = np.random.RandomState() if rng is None else rng
  matrices = np.float32(generator.normal(loc=0.0, scale=0.01, size=shape))
  # Address all diagonal entries at once instead of looping over them.
  diagonal = np.arange(shape[1])
  matrices[:, diagonal, diagonal] -= 10.0
  return matrices

def periodic_get_model(rng=None):
  """Builds the uncompiled regression model for the periodic task.

  Args:
    rng: Optional `np.random.RandomState` used to initialise the M-layer
      matrices; a fresh unseeded one is created when omitted.

  Returns:
    A `tf.keras` Sequential model taking 2-d inputs: Dense(2) -> MLayer
    (no bias) -> activity regularization -> Flatten -> one linear unit.
  """
  generator = np.random.RandomState() if rng is None else rng

  def matrix_init(shape, **kwargs):
    # Bind the generator so every matrix draw comes from the same stream.
    return periodic_matrix_init(shape, rng=generator, **kwargs)

  layers = [
      tf.keras.layers.Dense(
          2,
          input_shape=(2,),
          kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.1)),
      MLayer(PERIODIC_DIM_MATRIX, with_bias=False, matrix_init=matrix_init),
      tf.keras.layers.ActivityRegularization(l2=1e-4),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(1),
  ]
  return tf.keras.models.Sequential(layers)

def periodic_dist2(y_true, y_pred):
  """Returns `tf.nn.l2_loss` of the residual `y_true - y_pred`."""
  residual = y_true - y_pred
  return tf.nn.l2_loss(residual)

def periodic_run(get_model, basename, load_objects, epochs, batch_size, lr,
                 decay):
  """Fits models to sums of two cosines and plots how they extrapolate.

  For every frequency pair with 3 <= freq1 < freq2 <= 9, random phases and
  amplitudes are drawn, three models are trained on inputs in [0, 2] and their
  predictions on [0, 20] are plotted against the noise-free target.

  Args:
    get_model: Callable accepting an `rng` keyword argument and returning a
      Keras model that maps 2-d inputs to a single output.
    basename: Label appended to each figure title.
    load_objects: Unused by this function; kept so existing callers work.
    epochs: Number of training epochs per model.
    batch_size: Mini-batch size used by `model.fit`.
    lr: RMSprop learning rate.
    decay: `decay` argument of the RMSprop optimizer.
  """
  rng = np.random.RandomState(seed=PERIODIC_SEED)

  # The inputs are (1, t) pairs and do not depend on the frequencies, so they
  # are built once instead of once per frequency pair.
  num_samples = 2 * 10**5
  x_train = np.stack([
      np.ones(num_samples + 1),
      np.linspace(start=0.0, stop=2.0, num=num_samples + 1)
  ], axis=1)
  x_test = np.stack([
      np.ones(10 * num_samples + 1),
      np.linspace(start=0.0, stop=20.0, num=10 * num_samples + 1)
  ], axis=1)

  for freq1 in range(3, 10):
    for freq2 in range(freq1 + 1, 10):
      # The order of the random draws is kept fixed so that results for a
      # given seed do not change.
      phase1 = rng.uniform(high=np.pi / 3)
      phase2 = rng.uniform(high=np.pi / 3)
      coeff1 = rng.uniform(low=1.0, high=2.0)
      coeff2 = rng.uniform(low=5.0, high=10.0)

      def target(t):
        """Noise-free sum of the two cosines at positions `t`."""
        return (coeff1 * np.cos(2 * freq1 * np.pi * t + phase1) +
                coeff2 * np.cos(2 * freq2 * np.pi * t + phase2))

      y_train = (target(x_train[:, 1]) +
                 rng.normal(scale=1e-4, size=num_samples + 1))
      # Second target column is zero: it pairs with the penalty output below.
      train_targets = np.pad([y_train], [(0, 1), (0, 0)], mode='constant').T

      predictions_test = []
      for rep in range(3):
        model_to_train = get_model(rng=rng)

        input1 = tf.keras.layers.Input(shape=(2,))
        output1 = model_to_train(input1)
        # Re-evaluate the same model on shifted inputs (1, t) -> (1, 2t + 6),
        # which lie outside the training interval.
        input_plus_6 = tf.keras.layers.Lambda(
            lambda x: 2 * x + [-1, 6])(input1)
        output2_premult = model_to_train(input_plus_6)
        # Penalise only the part of those outputs whose magnitude exceeds 100.
        output2 = tf.keras.layers.Lambda(lambda x: tf.math.maximum(
            tf.constant(0, dtype=tf.dtypes.float32),
            tf.math.abs(x) - 100))(
                output2_premult)
        output = tf.keras.layers.Concatenate(axis=1)([output1, output2])
        model = tf.keras.models.Model(inputs=[input1], outputs=output)

        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        opt = tf.keras.optimizers.RMSprop(learning_rate=lr, decay=decay)

        model.compile(
            loss='mean_squared_error', optimizer=opt,
            metrics=[periodic_dist2])

        model.fit(
            x_train,
            train_targets,
            batch_size=batch_size,
            epochs=epochs,
            shuffle=True,
            verbose=2)

        # Column 0 is the actual regression output.
        predictions_test.append(model.predict(x_test)[:, 0])
        print('Run %d done' % rep)

      fig = plt.figure()
      axes = fig.gca()
      axes.set_title('%.2f cos (%dx+%.2f) + %.2f cos (%dx+%.2f) (%s)' %
                     (coeff1, freq1, phase1, coeff2, freq2, phase2, basename))

      axes.set_ylim([-15, 15])
      axes.set_xlim([0, 7])
      y_test = target(x_test[:, 1])
      axes.plot(x_test[:, 1], y_test, 'g', label='Target')
      styles = ['k:', 'k-', 'k--']
      for pred, style in zip(predictions_test, styles):
        axes.plot(x_test[:, 1], pred, style)
      axes.grid()
      # Shade the interval covered by the training data.
      axes.axvspan(0, 2, facecolor='b', alpha=0.2)
      axes.set_xlabel('input')
      axes.set_ylabel('output')
      plt.show()
      fig.show()

# Run the periodic-function experiment with the M-layer model. The dictionary
# is passed as `load_objects`, which `periodic_run` does not read.
periodic_run(periodic_get_model,
      'M-Layer', {'MLayer': MLayer},
      PERIODIC_EPOCHS,
      PERIODIC_BATCH_SIZE,
      lr=PERIODIC_LR,
      decay=PERIODIC_DECAY)