##### Copyright 2020 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# M-layer experiments
This notebook trains M-layers on the problems discussed in "Intelligent Matrix Exponentiation".



When running this notebook locally, the `m_layer` Python module is distributed together with this colab and should already be present.

Otherwise, the code of the `m_layer` Python module can be downloaded from the google-research GitHub repository.

In [None]:
import os.path
if os.path.isfile('m_layer.py'):
  from m_layer import MLayer
else:
  !if ! type "svn" > /dev/null; then sudo apt-get install subversion; fi
  !svn export https://github.com/google-research/google-research/trunk/m_layer
  from m_layer.m_layer import MLayer

In [None]:
GLOBAL_SEED = 1
import functools
import itertools
import logging
import operator

import numpy as np
np.random.seed(GLOBAL_SEED)

logging.getLogger('tensorflow').disabled = True

import matplotlib.pyplot as plt
from matplotlib import pylab
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

print(tf.__version__)
print(tf.config.experimental.list_physical_devices('GPU'))

# Generate a spiral and show extrapolation

In [None]:
# Hyperparameters of the spiral classification experiment.
SPIRAL_DIM_REP = 10  # Units of the Dense layer placed in front of the M-layer.
SPIRAL_DIM_MATRIX = 10  # Passed to MLayer as `dim_m`.
SPIRAL_LAYER_SIZE = 20  # Units of each hidden Dense layer in the baseline DNN.
SPIRAL_LR = 0.01  # Initial RMSprop learning rate.
SPIRAL_EPOCHS = 1000  # Maximum number of epochs; early stopping may end sooner.
SPIRAL_BATCH_SIZE = 16  # Mini-batch size passed to `model.fit`.

def spiral_m_layer_model():
  """Builds the M-layer binary classifier for 2-d spiral points.

  The stack is: Dense projection of the 2-d input to `SPIRAL_DIM_REP` units,
  an `MLayer` with `dim_m=SPIRAL_DIM_MATRIX`, an L2 activity regularizer,
  a Flatten, and a single sigmoid output unit.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model.
  """
  keras_layers = tf.keras.layers
  stack = [
      keras_layers.Dense(SPIRAL_DIM_REP, input_shape=(2,)),
      MLayer(dim_m=SPIRAL_DIM_MATRIX,
             with_bias=True,
             matrix_squarings_exp=None,
             matrix_init='normal'),
      keras_layers.ActivityRegularization(l2=1e-3),
      keras_layers.Flatten(),
      keras_layers.Dense(1, activation='sigmoid'),
  ]
  return tf.keras.models.Sequential(stack)

def spiral_dnn_model(activation_type):
  """Builds the baseline fully-connected classifier for 2-d spiral points.

  Args:
    activation_type: Activation handed to both hidden Dense layers.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model with two hidden layers
    of `SPIRAL_LAYER_SIZE` units each and one sigmoid output unit.
  """
  network = tf.keras.models.Sequential()
  network.add(tf.keras.layers.Flatten(input_shape=(2,)))
  for _ in range(2):
    network.add(tf.keras.layers.Dense(SPIRAL_LAYER_SIZE,
                                      activation=activation_type))
  network.add(tf.keras.layers.Dense(1, activation='sigmoid'))
  return network

def spiral_generate(n_points, noise=0.5, rng=None, extra_rotation=False):
  """Samples points from two interleaved spiral arms.

  Args:
    n_points: Number of points drawn per arm.
    noise: Scale of the uniform jitter added to each coordinate.
    rng: Optional `np.random.RandomState`; a fresh, unseeded one is created
      when omitted.
    extra_rotation: If True, sample the angular range that continues the
      spiral beyond the range used when False.

  Returns:
    A pair `(points, labels)`: `points` has shape `(2 * n_points, 2)`, with
    the first arm in the first half and its point-mirrored copy in the
    second half; `labels` has shape `(2 * n_points,)` and holds 0 for the
    first arm and 1 for the mirrored arm.
  """
  rng = np.random.RandomState() if rng is None else rng
  # Draw order matters for reproducibility: angle first, then x and y jitter.
  uniform = rng.rand(n_points, 1)
  if extra_rotation:
    radicand = (7.0/36)*uniform+.25
  else:
    radicand = 0.001 + (.25)*uniform
  angle = np.sqrt(radicand) * 6 * (2 * np.pi)
  jitter_x = (2 * rng.rand(n_points, 1) - 1) * noise
  jitter_y = (2 * rng.rand(n_points, 1) - 1) * noise
  x = 0.5 * (np.sin(angle) * angle + jitter_x)
  y = 0.5 * (np.cos(angle) * angle + jitter_y)
  first_arm = np.hstack((x, y))
  points = np.concatenate([first_arm, -first_arm], axis=0)
  labels = np.concatenate([np.zeros(n_points), np.ones(n_points)])
  return points, labels


def spiral_run(model_type, fig=None, activation_type=None):
  """Trains a classifier on the spiral data and plots its decision regions.

  The model is fitted on the inner part of the spiral and the plot also shows
  points from the outer continuation (`extra_rotation=True`), so the figure
  illustrates how the decision boundary extrapolates.

  Args:
    model_type: `"dnn"` selects the baseline from `spiral_dnn_model`; any
      other value selects `spiral_m_layer_model`.
    fig: Optional matplotlib figure to draw into; a new one is created when
      omitted.
    activation_type: Hidden-layer activation, used only for `"dnn"`.

  Returns:
    A tuple `(fig, n_epochs, final_loss)`: the figure, the number of epochs
    actually run, and the training loss recorded for the last epoch.
  """
  if fig is None:
    fig = pylab.figure(figsize=(8,8), dpi=144)
  if model_type == "dnn":
    model = spiral_dnn_model(activation_type)
  else:
    model = spiral_m_layer_model()

  # Seed the generator explicitly: `np.random.seed` does not affect the fresh
  # `RandomState` that `spiral_generate` would otherwise create, so without
  # this the data set differs on every run.
  rng = np.random.RandomState(GLOBAL_SEED)
  x_train, y_train = spiral_generate(1000, rng=rng)
  x_test, y_test = spiral_generate(333, rng=rng, extra_rotation=True)

  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  opt = tf.keras.optimizers.RMSprop(learning_rate=SPIRAL_LR)
  model.compile(loss='binary_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='loss', factor=0.2, patience=5, min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss',
                                                    patience=30,
                                                    min_delta=0.0001,
                                                    restore_best_weights=True)
  result = model.fit(x_train, y_train, epochs=SPIRAL_EPOCHS,
                     batch_size=SPIRAL_BATCH_SIZE, verbose=2,
                     callbacks=[reduce_lr, early_stopping])
  n_epochs = len(result.history['loss'])

  # Evaluate the model on a regular grid. Row k of `grid_points` is
  # [xs[i], ys[j]] with k = i * len(ys) + j, i.e. the first input varies
  # slowest.
  delta = 0.5 ** 3
  xs = np.arange(-14, 14.01, delta)
  ys = np.arange(-14, 14.01, delta)
  grid_first, grid_second = np.meshgrid(xs, ys, indexing='ij')
  grid_points = np.stack([grid_first.ravel(), grid_second.ravel()], axis=1)
  t_nn_gen = model.predict(grid_points)

  axes = fig.gca()
  XX, YY = np.meshgrid(xs, ys)
  # With the default 'xy' meshgrid the reshaped predictions end up with the
  # model's second input on the horizontal axis and the first input on the
  # vertical axis; the scatter plots below use the same convention.
  decision_surface = np.arcsinh(t_nn_gen.reshape(XX.shape))
  axes.contourf(XX, YY, decision_surface,
                levels=[0.0, 0.5, 1.0],
                colors=[(0.41, 0.67, 0.81, 0.2), (0.89, 0.51, 0.41, 0.2)])
  axes.contour(XX, YY, decision_surface,
               levels=[0.5])
  axes.set_aspect(1)
  axes.grid()
  axes.plot(x_train[y_train==0, 1], x_train[y_train==0, 0], '.', ms = 2,
            label='Class 1')
  axes.plot(x_train[y_train==1, 1], x_train[y_train==1, 0], '.', ms = 2,
            label='Class 2')
  # Draw on `axes` rather than through `plt`, so the test points land on the
  # caller's figure even when it is not the current pyplot figure.
  axes.plot(x_test[y_test==1, 1], x_test[y_test==1, 0], '.', ms = .5,
            label='Class 2')
  axes.plot(x_test[y_test==0, 1], x_test[y_test==0, 0], '.', ms = .5,
            label='Class 1')

  return fig, n_epochs, result.history['loss'][-1]

# Train the M-layer model on the spiral and plot its decision regions.
fig, n_epochs, loss = spiral_run('m_layer')

# Train an M-layer on multivariate polynomials such as the determinant

In [None]:
# Hyperparameters of the polynomial (determinant / permanent) experiment.
POLY_BATCH_SIZE = 32  # Mini-batch size passed to `model.fit`.
POLY_DIM_MATRIX = 8  # Passed to MLayer as `dim_m`.
POLY_DIM_INPUT_MATRIX = 3  # Side length of the square input matrices.
POLY_EPOCHS = 150  # Maximum number of epochs; early stopping may end sooner.
POLY_SEED = 123  # Seed of the RandomState that draws the input matrices.
POLY_LOW = -1  # Lower bound of the uniform distribution of matrix entries.
POLY_HIGH = 1  # Upper bound of the uniform distribution of matrix entries.
# `poly_run` draws 5/4 of this many matrices and holds 20% of them out for
# validation, so this is the number of samples actually fitted on.
POLY_NUM_SAMPLES = 8192
POLY_LR = 1e-3  # Initial RMSprop learning rate.
POLY_DECAY = 1e-6  # Passed to RMSprop as `decay`.

def poly_get_model():
  """Builds the M-layer regressor used for determinants and permanents.

  The square input matrix is flattened, fed through an `MLayer` with
  `dim_m=POLY_DIM_MATRIX`, L2-activity-regularized, flattened again and
  mapped to a single linear output.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model.
  """
  keras_layers = tf.keras.layers
  input_shape = (POLY_DIM_INPUT_MATRIX, POLY_DIM_INPUT_MATRIX)
  stack = [
      keras_layers.Flatten(input_shape=input_shape),
      MLayer(dim_m=POLY_DIM_MATRIX, matrix_init='normal'),
      keras_layers.ActivityRegularization(l2=1e-4),
      keras_layers.Flatten(),
      keras_layers.Dense(1),
  ]
  return tf.keras.models.Sequential(stack)


def poly_fun(x, permanent=False):
  """Returns the determinant or the permanent of a square matrix.

  Args:
    x: 2-d array indexed as `x[row, column]`; `x.shape[0]` gives the size.
    permanent: If True, compute the permanent by summing, over every
      permutation of the column indices, the product of the selected
      entries. Otherwise return `np.linalg.det(x)`.

  Returns:
    The scalar value of the requested polynomial.
  """
  if not permanent:
    return np.linalg.det(x)
  total = 0
  for columns in itertools.permutations(range(x.shape[0])):
    # One term of the permanent: pick entry (row, columns[row]) in each row.
    term = 1
    for row, col in enumerate(columns):
      term = term * x[row, col]
    total = total + term
  return total


def poly_run(permanent=False):
  """Trains an M-layer to regress a matrix polynomial and prints the MSE.

  Random square matrices with entries uniform in [POLY_LOW, POLY_HIGH) are
  labelled with `poly_fun`. The model is fitted with 20% of the training
  matrices held out for validation, then evaluated on the full training set
  and on 100000 freshly drawn test matrices.

  Args:
    permanent: If True the target is the permanent, otherwise the
      determinant.
  """
  rng = np.random.RandomState(seed=POLY_SEED)
  matrix_shape = (POLY_DIM_INPUT_MATRIX, POLY_DIM_INPUT_MATRIX)
  # Over-sample by 5/4 so that POLY_NUM_SAMPLES matrices remain for fitting
  # after the 20% validation split below.
  num_train = POLY_NUM_SAMPLES * 5 // 4
  x_train = rng.uniform(size=(num_train,) + matrix_shape, low=POLY_LOW,
                        high=POLY_HIGH)
  x_test = rng.uniform(size=(100000,) + matrix_shape, low=POLY_LOW,
                       high=POLY_HIGH)
  y_train = np.array([poly_fun(x, permanent=permanent) for x in x_train])
  y_test = np.array([poly_fun(x, permanent=permanent) for x in x_test])
  model = poly_get_model()
  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  # NOTE(review): newer Keras optimizers reportedly reject `decay`; confirm
  # against the TensorFlow version in use before upgrading.
  opt = tf.keras.optimizers.RMSprop(learning_rate=POLY_LR, decay=POLY_DECAY)

  model.compile(loss='mse', optimizer=opt)
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='val_loss', factor=0.2, patience=5, min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                    patience=30,
                                                    restore_best_weights=True)
  model.fit(x_train, y_train, batch_size=POLY_BATCH_SIZE,
            epochs=POLY_EPOCHS,
            validation_split=0.2,
            shuffle=True,
            verbose=2,
            callbacks=[reduce_lr, early_stopping])
  score_train = model.evaluate(x=x_train, y=y_train)
  score_test = model.evaluate(x=x_test, y=y_test)

  print('Train, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_train))
  print('Test, range %s - %s: %s' % (POLY_LOW, POLY_HIGH, score_test))

Permanents

In [None]:
# Fit the M-layer to matrix permanents.
poly_run(permanent=True)

Determinants

In [None]:
# Fit the M-layer to matrix determinants.
poly_run(permanent=False)

# Train an M-layer on periodic data

In [None]:
# Hyperparameters of the periodic (temperature) experiment.
PERIODIC_EPOCHS = 1000  # Maximum number of training epochs.
PERIODIC_BATCH_SIZE = 128  # Mini-batch size passed to `model.fit`.
PERIODIC_LR = 0.00001  # Initial RMSprop learning rate.
PERIODIC_DIM_MATRIX = 10  # First positional argument given to MLayer.
PERIODIC_INIT_SCALE = 0.01  # Std-dev of the normal noise in the matrix init.
PERIODIC_DIAG_INIT = 10  # Amount subtracted from diagonal entries at init.
PERIODIC_SEED = 123  # Seed of the RandomState used by `periodic_run`.

def periodic_matrix_init(shape, rng=None, **kwargs):
  """Creates initial matrices: small normal noise with a shifted diagonal.

  Args:
    shape: Shape of the array to create. The code indexes it as
      `data[:, i, i]`, so it is expected to be 3-d with equal last two
      dimensions.
    rng: Optional `np.random.RandomState`; a fresh, unseeded one is created
      when omitted.
    **kwargs: Accepted for compatibility with the caller and ignored.

  Returns:
    A float32 array of the given shape, drawn from
    N(0, PERIODIC_INIT_SCALE), with PERIODIC_DIAG_INIT subtracted from every
    `[:, i, i]` entry.
  """
  generator = np.random.RandomState() if rng is None else rng
  noise = generator.normal(loc=0, scale=PERIODIC_INIT_SCALE, size=shape)
  data = np.float32(noise)
  diagonal = np.arange(shape[1])
  data[:, diagonal, diagonal] -= PERIODIC_DIAG_INIT
  return data

def periodic_get_model(rng=None):
  """Builds the M-layer regressor for the periodic time series.

  Args:
    rng: Optional `np.random.RandomState` forwarded to
      `periodic_matrix_init`; a fresh, unseeded one is created when omitted.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model mapping a shape-(1,)
    input to one linear output.
  """
  if rng is None:
    rng = np.random.RandomState()

  def matrix_init(shape, **kwargs):
    # Bind the shared generator so every call draws from the same stream.
    return periodic_matrix_init(shape, rng=rng, **kwargs)

  input_projection = tf.keras.layers.Dense(
      2, input_shape=(1,),
      kernel_initializer=tf.keras.initializers.RandomNormal())
  m_layer = MLayer(PERIODIC_DIM_MATRIX, with_bias=True,
                   matrix_squarings_exp=None, matrix_init=matrix_init)
  return tf.keras.models.Sequential([
      input_projection,
      m_layer,
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(1),
  ])


def periodic_dist2(y_true, y_pred):
  """Metric: `tf.nn.l2_loss` of the residual between targets and predictions."""
  residual = y_true - y_pred
  return tf.nn.l2_loss(residual)

def periodic_run(get_model):
  """Fits a model to smoothed daily temperatures and plots its forecast.

  Reads `daily-min-temperatures.csv` (columns `Date` and `Temp`), centers the
  temperatures and smooths them with a 7-sample moving average. The sample
  index is the only input feature. The first 90% of the series is used for
  training; predictions are plotted over the whole series together with the
  training and held-out targets.

  Args:
    get_model: Callable accepting an `rng` keyword argument (a
      `np.random.RandomState`) and returning a Keras model that maps a
      shape-(1,) input to one output.
  """
  rng = np.random.RandomState(seed=PERIODIC_SEED)
  # See README file for information about this dataset.
  # `tf.io.gfile.GFile` replaces the previously undefined `gfile.Open`.
  with tf.io.gfile.GFile('daily-min-temperatures.csv', 'r') as f:
    data = pd.read_csv(f)
  dates = data['Date']
  temperatures = data['Temp']
  # Center, then smooth with a 7-sample moving average. `mode='valid'` drops
  # the 6 samples for which the window would not fit entirely.
  y = list(np.convolve(temperatures - np.mean(temperatures), np.full(7, 1 / 7),
                       mode='valid'))
  num_train = 9 * len(y) // 10
  x_all = np.arange(len(y)).tolist()
  x_train = x_all[:num_train]
  y_train = y[:num_train]
  x_test = x_all[num_train:]
  y_targets = y[num_train:]

  model_to_train = get_model(rng=rng)
  model_input = tf.keras.layers.Input(shape=(1,))
  model_output = model_to_train(model_input)
  model = tf.keras.models.Model(inputs=model_input, outputs=model_output)

  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  # The former `decay=0` equalled the default and has been dropped.
  opt = tf.keras.optimizers.RMSprop(learning_rate=PERIODIC_LR)
  # NOTE(review): no `monitor` is given, so Keras' default (`val_loss`)
  # applies, but `fit` below receives no validation data. Confirm whether
  # `monitor='loss'` (with a suitable patience) was intended.
  early_stopping = tf.keras.callbacks.EarlyStopping(restore_best_weights=True)
  model.compile(
      loss='mean_squared_error', optimizer=opt,
      metrics=[periodic_dist2])
  model.fit(x_train, y_train,
            batch_size=PERIODIC_BATCH_SIZE, epochs=PERIODIC_EPOCHS,
            shuffle=True, verbose=1, callbacks=[early_stopping])
  y_predictions = model.predict(x_all)

  plt.plot(x_train, y_train, linewidth=1, alpha=0.7)
  plt.plot(x_test, y_targets, linewidth=1, alpha=0.7)
  plt.plot(x_all, y_predictions, color='magenta')
  plt.legend(['y_train', 'y_targets', 'y_predictions'])
  plt.xlim([0, 3650])
  plt.ylabel('Temperature (Celsius)')
  plt.grid(True, which='major', axis='both')
  plt.grid(True, which='minor', axis='both')
  # Put one tick at every January 1st, labelled with the year.
  xtick_index = [i for i, date in enumerate(dates) if date.endswith('-01-01')]
  plt.xticks(ticks=xtick_index,
             labels=[x[:4] for x in dates[xtick_index].to_list()],
             rotation=30)
  plt.show()

# Train the M-layer model on the temperature series and plot the result.
periodic_run(periodic_get_model)

# Train an M-layer on CIFAR-10


In [None]:
# Hyperparameters of the CIFAR-10 experiment.
CIFAR_DIM_REP = 35  # Units of the Dense layer placed in front of the M-layer.
CIFAR_DIM_MAT = 30  # Passed to MLayer as `dim_m`.
CIFAR_LR = 1e-3  # Initial SGD learning rate.
CIFAR_DECAY = 1e-6  # Passed to SGD as `decay`.
CIFAR_MOMENTUM = 0.9  # SGD momentum.
CIFAR_BATCH_SIZE = 32  # Mini-batch size passed to `model.fit`.
CIFAR_EPOCHS = 150  # Maximum number of epochs; early stopping may end sooner.
CIFAR_NAME = 'cifar10'  # Dataset name handed to `tfds.load`.
CIFAR_NUM_CLASSES = 10  # Number of one-hot classes / output units.

def cifar_load_dataset():
  """Loads the dataset named by CIFAR_NAME as in-memory numpy arrays.

  Images are converted to float32 and divided by 255; labels are one-hot
  encoded with CIFAR_NUM_CLASSES classes. Sample counts are printed.

  Returns:
    `((x_train, y_train), (x_test, y_test))`.
  """
  arrays = {}
  for split_name in ('train', 'test'):
    # `batch_size=-1` makes tfds return the entire split in one batch.
    split = tfds.load(CIFAR_NAME, split=split_name, with_info=False,
                      batch_size=-1)
    arrays[split_name] = tfds.as_numpy(split)

  x_train = arrays['train']['image']
  y_train = arrays['train']['label']
  x_test = arrays['test']['image']
  y_test = arrays['test']['label']
  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  train_pair = (x_train.astype('float32') / 255,
                tf.keras.utils.to_categorical(y_train, CIFAR_NUM_CLASSES))
  test_pair = (x_test.astype('float32') / 255,
               tf.keras.utils.to_categorical(y_test, CIFAR_NUM_CLASSES))
  return train_pair, test_pair

def cifar_get_model():
  """Builds the M-layer image classifier for 32x32x3 inputs.

  Returns:
    An uncompiled `tf.keras.models.Sequential` model ending in a softmax
    over CIFAR_NUM_CLASSES classes.
  """
  keras_layers = tf.keras.layers
  stack = [
      keras_layers.Flatten(input_shape=(32, 32, 3)),
      keras_layers.Dense(CIFAR_DIM_REP),
      MLayer(dim_m=CIFAR_DIM_MAT, with_bias=True, matrix_squarings_exp=3),
      # The first positional argument of ActivityRegularization is `l1`.
      # NOTE(review): the other models in this notebook regularize with
      # `l2`; confirm that L1 is intended here.
      keras_layers.ActivityRegularization(l1=1e-3),
      keras_layers.Flatten(),
      keras_layers.Dense(CIFAR_NUM_CLASSES, activation='softmax'),
  ]
  return tf.keras.models.Sequential(stack)

def cifar_run():
  """Trains the M-layer classifier on CIFAR-10 and prints test metrics.

  10% of the training images are held out for validation. The learning rate
  is reduced and training is stopped early based on validation accuracy;
  the best weights are restored before evaluating on the test split.
  """
  (x_train, y_train), (x_test, y_test) = cifar_load_dataset()
  model = cifar_get_model()
  model.summary()
  # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
  # NOTE(review): newer Keras optimizers reportedly reject `decay`; confirm
  # against the TensorFlow version in use before upgrading.
  opt = tf.keras.optimizers.SGD(learning_rate=CIFAR_LR,
                                momentum=CIFAR_MOMENTUM,
                                decay=CIFAR_DECAY)

  model.compile(loss='categorical_crossentropy', optimizer=opt,
                metrics=['accuracy'])

  # With `metrics=['accuracy']` TF2 Keras logs the validation metric as
  # `val_accuracy`. Monitoring the old name `val_acc` made both callbacks
  # silently inactive.
  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
      monitor='val_accuracy', mode='max', factor=0.2, patience=5,
      min_lr=1e-5)
  early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                                    mode='max',
                                                    patience=15,
                                                    restore_best_weights=True)

  model.fit(
      x_train,
      y_train,
      batch_size=CIFAR_BATCH_SIZE,
      epochs=CIFAR_EPOCHS,
      validation_split=0.1,
      shuffle=True,
      verbose=2,
      callbacks=[reduce_lr, early_stopping])

  scores = model.evaluate(x_test, y_test, verbose=0)
  print('Test loss:', scores[0])
  print('Test accuracy:', scores[1])

# Train and evaluate the M-layer classifier on CIFAR-10.
cifar_run()