# Install Importance Sampling module for Keras

In [6]:
# Install Dependency
!pip3 install blinker

# Clone Repo
!pip install --user keras-importance-sampling
!git clone https://github.com/idiap/importance-sampling.git
  
# Convert the repo into a module for usage in the notebook
!touch importance-sampling/__init__.py
!touch importance-sampling/importance_sampling/__init__.py
!touch importance-sampling/examples/__init__.py
!cp -r importance-sampling/* .

# Copy contents of repo for easy usage 
!cp -r importance-sampling/* .

fatal: destination path 'importance-sampling' already exists and is not an empty directory.


# Imports

In [105]:
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
import numpy as np

import time

from keras import backend as K
from keras.callbacks import LearningRateScheduler, Callback
from keras.datasets import cifar10
from keras.layers import Activation, BatchNormalization, Conv2D, Dense, \
    GlobalAveragePooling2D, Input, add
from keras.models import Model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.regularizers import l2
from keras.utils import to_categorical
from keras.models import load_model

from importance_sampling.datasets import CIFAR10, ZCAWhitening
from importance_sampling.models import wide_resnet
from importance_sampling.training import ImportanceTraining
from examples.example_utils import get_parser

import matplotlib.pyplot as plt

from IPython.display import Image, display

from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# Wide Resnet Baseline on CIFAR 10

## Helper Classes

In [0]:
class TrainingSchedule(Callback):
    """Implement the training schedule for training a resnet on CIFAR10 for a
    given time budget."""
    def __init__(self, total_time):
        self._total_time = total_time
        self._lr = self._get_lr(0.0)

    def _get_lr(self, progress):
        if progress > 0.8:
            return 0.004
        elif progress > 0.5:
            return 0.02
        else:
            return 0.1

    def on_train_begin(self, logs={}):
        self._start = time.time()
        self._lr = self._get_lr(0.0)
        K.set_value(self.model.optimizer.lr, self._lr)

    def on_batch_end(self, batch, logs):
        t = time.time() - self._start

        if t >= self._total_time:
            self.model.stop_training = True

        lr = self._get_lr(t / self._total_time)
        if lr != self._lr:
            self._lr = lr
            K.set_value(self.model.optimizer.lr, self._lr)

    @property
    def lr(self):
        return self._lr

class Args():
  def __init__(self, depth, width, presample, batch_size, time_budget, dropout, useIS):
    self.depth = depth
    self.width = width
    self.presample = presample
    self.batch_size = batch_size
    self.time_budget = 3600*time_budget
    self.dropout = dropout
    self.importance_training = useIS

## Load Dataset

In [0]:
# Load Dataset
dset = ZCAWhitening(CIFAR10())

## Model Utility Functions

### Build Model

In [0]:
def build_model(args, dset):
  training_schedule = TrainingSchedule(args.time_budget)
  model = wide_resnet(args.depth, args.width, args.dropout)(dset.shape, dset.output_size)
  model.compile(
      loss="categorical_crossentropy",
      optimizer=SGD(lr=training_schedule.lr, momentum=0.9),
      metrics=["accuracy"]
  )
  model.summary()
  return model, training_schedule

### Train Model

In [0]:
def train_model(model, args, x_train, y_train, x_test, y_test, training_schedule):
  # Create the data augmentation generator
  datagen = ImageDataGenerator(
      # set input mean to 0 over the dataset
      featurewise_center=False,
      # set each sample mean to 0
      samplewise_center=False,
      # divide inputs by std of dataset
      featurewise_std_normalization=False,
      # divide each input by its std
      samplewise_std_normalization=False,
      # apply ZCA whitening
      zca_whitening=False,
      # randomly rotate images in the range (deg 0 to 180)
      rotation_range=0,
      # randomly shift images horizontally
      width_shift_range=0.1,
      # randomly shift images vertically
      height_shift_range=0.1,
      # randomly flip images
      horizontal_flip=True,
      # randomly flip images
      vertical_flip=False)
  datagen.fit(x_train)

  # Train the model
  if args.importance_training:
      history = ImportanceTraining(model).fit_generator(
          datagen.flow(x_train, y_train, batch_size=args.batch_size),
          validation_data=(x_test, y_test),
          epochs=10**6,
          verbose=1,
          callbacks=[training_schedule],
          batch_size=args.batch_size,
          steps_per_epoch=int(np.ceil(float(len(x_train)) / args.batch_size))
      )
  else:
      history = model.fit_generator(
          datagen.flow(x_train, y_train, batch_size=args.batch_size),
          validation_data=(x_test, y_test),
          epochs=10**6,
          verbose=1,
          callbacks=[training_schedule]
      )
  return model, history

### Evaluate Model

In [0]:
# Score trained model.
def eval_model(model, x_test, y_test):
  scores = model.evaluate(x_test, y_test, verbose=1)
  print('Test loss:', scores[0])
  print('Test accuracy:', scores[1])

### Wrapper 

In [0]:
def build_wide_res(args, dset):
  # Split dataset
  x_train, y_train = dset.train_data[:]
  x_test, y_test = dset.test_data[:]

  model, training_schedule = build_model(args, dset)
  model, history = train_model(model, args, x_train, y_train, x_test, y_test, training_schedule)
  return model, history

### Saving the model and plots to Drive

In [0]:
def save_model(model, file_name):
  path = "gdrive/My Drive/"
  model.save(path + file_name + '.h5')
  
def save_plots(model, history, file_name):
  path = "gdrive/My Drive" + file_name
  
  # summarize history for accuracy
  plt.plot(history.history['accuracy'][:-1])
  plt.plot(history.history['val_accuracy'][:-1])
  plt.title('model accuracy')
  plt.ylabel('accuracy')
  plt.xlabel('epoch')
  plt.legend(['train', 'test'], loc='upper left')
  plt.savefig(path + "-acc.png")
  plt.show()

  # summarize history for loss
  plt.plot(history.history['loss'][:-1])
  plt.plot(history.history['val_loss'][:-1])
  plt.title('model loss')
  plt.ylabel('loss')
  plt.xlabel('epoch')
  plt.legend(['train', 'test'], loc='upper left')
  plt.savefig(path + "-loss.png")
  plt.show()
  
def save(model, history, file_name):
  save_model(model, file_name)
  save_plots(model, histiry, file_name)

## Evaluation

### Wide ResNet 28-2 with 0.3 dropout and max time of 1 hr (with IS)

In [47]:
args = Args(depth = 28, width = 2, presample = 3.0, batch_size = 128, time_budget = 1, dropout = 0.3, useIS = True)
model, history = build_wide_res(args, dset)

Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 32, 32, 16)   432         input_5[0][0]                    
__________________________________________________________________________________________________
layer_normalization_51 (LayerNo (None, 32, 32, 16)   17          conv2d_57[0][0]                  
__________________________________________________________________________________________________
activation_53 (Activation)      (None, 32, 32, 16)   0           layer_normalization_51[0][0]

[NOTICE]: You are using BatchNormalization and/or Dropout.
Those layers may affect the importance calculations and you are advised to exchange them for LayerNormalization or BatchNormalization in test mode and L2 regularization.


Epoch 1/1000000
Epoch 2/1000000
Epoch 3/1000000
Epoch 4/1000000
Epoch 5/1000000
Epoch 6/1000000
Epoch 7/1000000
Epoch 8/1000000
Epoch 9/1000000
Epoch 10/1000000
Epoch 11/1000000
Epoch 12/1000000
Epoch 13/1000000
Epoch 14/1000000
Epoch 15/1000000
Epoch 16/1000000
Epoch 17/1000000
Epoch 18/1000000
Epoch 19/1000000
Epoch 20/1000000
Epoch 21/1000000
Epoch 22/1000000
Epoch 23/1000000
Epoch 24/1000000
Epoch 25/1000000
Epoch 26/1000000
Epoch 27/1000000
Epoch 28/1000000
Epoch 29/1000000
Epoch 30/1000000
Epoch 31/1000000
Epoch 32/1000000
Epoch 33/1000000
Epoch 34/1000000
Epoch 35/1000000
Epoch 36/1000000
Epoch 37/1000000
Epoch 38/1000000
Epoch 39/1000000
Epoch 40/1000000
Epoch 41/1000000
Epoch 42/1000000
Epoch 43/1000000
Epoch 44/1000000
Epoch 45/1000000
Epoch 46/1000000
Epoch 47/1000000
Epoch 48/1000000
Epoch 49/1000000
Epoch 50/1000000
Epoch 51/1000000
Epoch 52/1000000
Epoch 53/1000000
Epoch 54/1000000

In [50]:
# Test the model
x_test, y_test = dset.test_data[:]
eval_model(model, x_test, y_test)

Test loss: 0.507706770992279
Test accuracy: 0.9033


In [0]:
# Save model
save(model, history, 'res-net-28-2-1hr-dropout-IS')

In [106]:
# Load model and draw plots 
path = "gdrive/My Drive/"
file_name = "res-net-28-2-1hr-dropout-IS"
model = load_model(path + file_name + ".h5")

# Display Saved plots
listOfImageNames = [path + file_name + "-acc.png",
                    path + file_name + "-loss.png"]

for imageName in listOfImageNames:
    display(Image(filename=imageName))

ValueError: ignored

### Wide ResNet 28-2 with 0.3 dropout and max time of 1 hr (without IS)

In [0]:
args = Args(depth = 28, width = 2, presample = 3.0, batch_size = 128, time_budget = 1, dropout = 0.3, useIS = False)
model, history = build_wide_res(args, dset)

In [0]:
# Save model and Eval


### Wide ResNet 28-2 without dropout and max time of 1 hr (with IS)

In [0]:
args = Args(depth = 28, width = 2, presample = 3.0, batch_size = 128, time_budget = 1, dropout = 0, useIS = True)
model, history = build_wide_res(args, dset)

In [0]:
# Save model and Eval

### Wide ResNet 28-2 without dropout and max time of 1 hr (without IS)

In [0]:
args = Args(depth = 28, width = 2, presample = 3.0, batch_size = 128, time_budget = 1, dropout = 0, useIS = False)
model, history = build_wide_res(args, dset)