In [None]:
from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.models import Model, Sequential
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPooling2D, Input, Dense
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras.initializers import glorot_uniform, RandomNormal
import random, os, pickle, copy
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
%matplotlib inline
# Output locations for checkpoints / pickled histories and for figures.
savedir = './save/transfer_learning_cifar5_multiple_instantiations_small'
# expanduser('~') also works where $HOME is unset (os.getenv would return None
# and make os.path.join raise).
fig_dir = os.path.join(os.path.expanduser('~'), 'Dropbox/uniqueness_cnn_figures')
# ModelCheckpoint and the pickle.dump calls write into savedir, so it must exist.
if not os.path.isdir(savedir):
    os.makedirs(savedir)

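# Set up data

Load CIFAR-10, scale pixels to [0, 1], and split the ten classes into two disjoint 5-class tasks (dset1 and dset2).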

In [None]:
def map_to_range(arr):
    """Relabel the values of ``arr`` onto the contiguous range ``0..k-1``.

    Each distinct value is replaced by its rank among the sorted distinct
    values, e.g. ``[3, 7, 3, 0] -> [1, 2, 1, 0]``.

    Parameters
    ----------
    arr : array_like
        Values to relabel; any shape.

    Returns
    -------
    numpy.ndarray
        ``int64`` array with the same shape as ``arr``.
    """
    arr = np.asarray(arr)
    # One sort-based pass instead of building a full comparison mask per value.
    _, inverse = np.unique(arr, return_inverse=True)
    # np.unique flattens n-d input; the inverse's shape has varied between NumPy
    # versions, so restore the input shape explicitly.
    return inverse.reshape(arr.shape).astype(np.int64)

In [None]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Scale pixel intensities from [0, 255] to [0, 1] (in place, to avoid an extra copy).
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
input_shape = (32, 32, 3)

# CIFAR-10 label ids of the two disjoint 5-class tasks; these splits were found
# to give roughly equal classification performance.
set_1_classes = np.array([0, 3, 4, 6, 7])
set_2_classes = np.array([1, 2, 5, 8, 9])
assert np.intersect1d(set_1_classes, set_2_classes).size == 0, 'task class sets must be disjoint'
assert len(set_1_classes) == len(set_2_classes), 'both tasks share one softmax layer size'
# Derived from the class lists rather than stated as a separate magic number.
num_classes = len(set_1_classes)

# Squeeze the label arrays to 1-D int64 so the masks below are already 1-D.
y_test = y_test.astype(np.int64).squeeze()
y_train = y_train.astype(np.int64).squeeze()

set_1_test_mask = np.isin(y_test, set_1_classes)
set_1_train_mask = np.isin(y_train, set_1_classes)

set_2_test_mask = np.isin(y_test, set_2_classes)
set_2_train_mask = np.isin(y_train, set_2_classes)

set_1_train_labels = y_train[set_1_train_mask]
x_train_1 = x_train[set_1_train_mask]
set_1_test_labels = y_test[set_1_test_mask]
x_test_1 = x_test[set_1_test_mask]
set_2_train_labels = y_train[set_2_train_mask]
x_train_2 = x_train[set_2_train_mask]
set_2_test_labels = y_test[set_2_test_mask]
x_test_2 = x_test[set_2_test_mask]

# map_to_range ranks the labels present in each split independently, so the
# train and test one-hot columns only line up if every split has every class.
for split_labels, split_classes in ((set_1_train_labels, set_1_classes),
                                    (set_1_test_labels, set_1_classes),
                                    (set_2_train_labels, set_2_classes),
                                    (set_2_test_labels, set_2_classes)):
    assert np.array_equal(np.unique(split_labels), split_classes), 'a split is missing a class'

y_train_1 = keras.utils.to_categorical(map_to_range(set_1_train_labels), num_classes)
y_train_2 = keras.utils.to_categorical(map_to_range(set_2_train_labels), num_classes)
y_test_1 = keras.utils.to_categorical(map_to_range(set_1_test_labels), num_classes)
y_test_2 = keras.utils.to_categorical(map_to_range(set_2_test_labels), num_classes)

In [None]:
# Sorted distinct test labels and the index where each first occurs.
unique_test_labels, first_occurrence = np.unique(y_test, return_index=True)
unique_test_labels, first_occurrence

In [None]:
# Light augmentation: small rotations and shifts plus horizontal flips.
datagen = ImageDataGenerator(rotation_range=10,
                             width_shift_range=0.05,
                             height_shift_range=0.05,
                             horizontal_flip=True
                            )
# Training cells refer to `datagen1` (dset1) and `datagen2` (dset2). No .fit() is
# called on the generator in this notebook, so it holds no per-dataset
# statistics and both names can share this one instance.
datagen1 = datagen
datagen2 = datagen

In [None]:
def set_up_model(lr=1e-3, trainable=(True, True, True, True), activations=('relu', 'relu', 'relu'),
                 n_classes=5):
    """Build and compile the small 3-conv + 1-dense CNN used in every experiment.

    The Keras session is cleared first, so any previously built model is discarded.

    Parameters
    ----------
    lr : float
        Adam learning rate.
    trainable : sequence of 4 bool
        Trainability of conv1, conv2, conv3 and the hidden dense layer. The
        softmax output layer is always trainable.
    activations : sequence of 3 str
        Activations of conv1, conv2 and conv3.
    n_classes : int
        Width of the softmax output layer.

    Returns
    -------
    keras.models.Model
        Compiled with categorical cross-entropy, Adam and the accuracy metric.

    Raises
    ------
    ValueError
        If ``trainable`` or ``activations`` has the wrong length.
    """
    if len(trainable) != 4:
        raise ValueError('trainable needs 4 flags (conv1, conv2, conv3, dense); got %d' % len(trainable))
    if len(activations) != 3:
        raise ValueError('activations needs 3 entries (conv1, conv2, conv3); got %d' % len(activations))
    keras.backend.clear_session()
    inputs = Input(shape=input_shape)
    x = inputs
    # Three stages of 3x3 'same' conv followed by 2x2 max-pooling, with 2, 4, 8 filters.
    for filters, activation, is_trainable in zip((2, 4, 8), activations, trainable[:3]):
        x = Conv2D(filters, kernel_size=(3, 3), strides=(1, 1), activation=activation,
                   padding='same', trainable=is_trainable)(x)
        x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(32, activation='relu', trainable=trainable[3])(x)
    x = Dropout(0.5)(x)
    predictions = Dense(n_classes, activation='softmax')(x)
    model = Model(inputs=inputs, outputs=predictions)
    # `lr=` matches the older Keras API this notebook targets (fit_generator,
    # ModelCheckpoint(period=...)); presumably newer Keras needs `learning_rate=`.
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adam(lr=lr),
                  metrics=['accuracy'])
    return model

In [None]:
# Training budget shared by every run below.
epochs, batch_size = 250, 128

## Train models from scratch

In [None]:
# Train on dset1 from scratch. The transfer cells load run 0, so the run index
# is fixed here rather than taken from a loop variable `i` that is undefined on
# a fresh kernel (the weights file was already hard-coded to "_0").
run_idx = 0
model = set_up_model(trainable=[True, True, True, True])
# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_training_dset1_%d.h5' % run_idx),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_1, y_train_1, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_1, y_test_1),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_training_dset1_%d.pkl' % run_idx), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Train on dset2 from scratch. The transfer and plotting cells read run 0
# ("_0" files), so the run index is fixed here instead of an undefined `i`.
run_idx = 0
model = set_up_model(trainable=[True, True, True, True])
# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_training_dset2_%d.h5' % run_idx),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_2, y_train_2, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_2, y_test_2),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_training_dset2_%d.pkl' % run_idx), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Reload the per-epoch histories pickled by the two from-scratch runs.
scratch_path_1 = os.path.join(savedir, 'history_training_dset1_0.pkl')
with open(scratch_path_1, 'rb') as fh:
    history_dset1 = pickle.load(fh)
scratch_path_2 = os.path.join(savedir, 'history_training_dset2_0.pkl')
with open(scratch_path_2, 'rb') as fh:
    history_dset2 = pickle.load(fh)

In [None]:
# Validation loss per epoch of the two from-scratch runs.
plt.plot(history_dset1['val_loss'], label='dset1 (scratch)')
plt.plot(history_dset2['val_loss'], label='dset2 (scratch)')
plt.xlabel('epoch')
plt.ylabel('validation loss')
plt.title('Training from scratch')
plt.legend();

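# Freeze 1 conv layer

Start from a model trained from scratch on one task, keep its first conv layer fixed (non-trainable), re-initialize the layers after it, and train on the other task.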

In [None]:
# Transfer dset1 -> dset2 with 1 frozen conv layer: conv1 keeps its dset1
# weights (non-trainable); every later weighted layer is re-initialized and
# trained on dset2.
n_transferred = 1  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, True, True, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset1_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset1_dset2_0.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_2, y_train_2, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_2, y_test_2),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_transfer_dset1_dset2_0.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Transfer dset2 -> dset1 with 1 frozen conv layer: conv1 keeps its dset2
# weights (non-trainable); every later weighted layer is re-initialized and
# trained on dset1.
n_transferred = 1  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, True, True, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset2_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset2_dset1_0.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_1, y_train_1, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_1, y_test_1),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_transfer_dset2_dset1_0.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Reload the freeze-1 transfer histories (files carry the "_0" suffix).
freeze1_path_12 = os.path.join(savedir, 'history_transfer_dset1_dset2_0.pkl')
with open(freeze1_path_12, 'rb') as fh:
    history_dset1_dset2 = pickle.load(fh)
freeze1_path_21 = os.path.join(savedir, 'history_transfer_dset2_dset1_0.pkl')
with open(freeze1_path_21, 'rb') as fh:
    history_dset2_dset1 = pickle.load(fh)

In [None]:
# Validation loss per epoch when the first conv layer is transferred and frozen.
plt.plot(history_dset2_dset1['val_loss'], label='dset2->dset1', color='orange')
plt.plot(history_dset1_dset2['val_loss'], label='dset1->dset2', color='deepskyblue')
plt.xlabel('epoch')
plt.ylabel('validation loss')
plt.title('Transfer, 1 frozen conv layer')
plt.legend();

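# Freeze 2 conv layers

As above, but the first two conv layers are kept fixed and only the remaining layers are re-initialized and trained on the other task.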

In [None]:
# Transfer dset1 -> dset2 with 2 frozen conv layers: conv1 and conv2 keep their
# dset1 weights (non-trainable); every later weighted layer is re-initialized
# and trained on dset2.
n_transferred = 2  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, False, True, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset1_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset1_dset2_freeze_2.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_2, y_train_2, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_2, y_test_2),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_transfer_dset1_dset2_freeze_2.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Transfer dset2 -> dset1 with 2 frozen conv layers: conv1 and conv2 keep their
# dset2 weights (non-trainable); every later weighted layer is re-initialized
# and trained on dset1.
n_transferred = 2  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, False, True, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset2_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset2_dset1_freeze_2.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_1, y_train_1, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_1, y_test_1),
                              callbacks=[ckpt_training]
                              )
# Same name as every other training cell; the old name is kept as an alias.
history_training = history
with open(os.path.join(savedir, 'history_transfer_dset2_dset1_freeze_2.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Reload the freeze-2 transfer histories. The dset1 -> dset2 file is the one the
# freeze-2 dset1 -> dset2 cell writes (a "dset1_dset1" file is never written),
# and it is bound to `history_dset1_dset2`, the name the plot below reads;
# otherwise that plot would show stale freeze-1 data.
with open(os.path.join(savedir, 'history_transfer_dset1_dset2_freeze_2.pkl'), 'rb') as f:
    history_dset1_dset2 = pickle.load(f)
with open(os.path.join(savedir, 'history_transfer_dset2_dset1_freeze_2.pkl'), 'rb') as f:
    history_dset2_dset1 = pickle.load(f)

In [None]:
# Validation loss per epoch when the first two conv layers are transferred and frozen.
plt.plot(history_dset2_dset1['val_loss'], label='dset2->dset1', color='orange')
plt.plot(history_dset1_dset2['val_loss'], label='dset1->dset2', color='deepskyblue')
plt.xlabel('epoch')
plt.ylabel('validation loss')
plt.title('Transfer, 2 frozen conv layers')
plt.legend();

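# Freeze 3 conv layers

As above, but all three conv layers are kept fixed; only the dense layers are re-initialized and trained on the other task.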

In [None]:
# Transfer dset2 -> dset1 with all 3 conv layers frozen: conv1-conv3 keep their
# dset2 weights (non-trainable); the dense layers are re-initialized and
# trained on dset1.
n_transferred = 3  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, False, False, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset2_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss. (Filename typo
# "dset_1freeze_3" fixed to match the other transfer checkpoints.)
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset2_dset1_freeze_3.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_1, y_train_1, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_1, y_test_1),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_transfer_dset2_dset1_freeze_3.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Transfer dset1 -> dset2 with all 3 conv layers frozen: conv1-conv3 keep their
# dset1 weights (non-trainable); the dense layers are re-initialized and
# trained on dset2.
n_transferred = 3  # leading weighted layers copied from the source model
model = set_up_model(trainable=[False, False, False, True])
# Snapshot the freshly built per-layer weights *before* loading, so the
# re-initialized layers get exactly their default initializers. Re-drawing every
# array with glorot_uniform also randomized the biases, and K.get_session() is
# only available on TF1-era Keras.
fresh_weights = [layer.get_weights() for layer in model.layers]
model.load_weights(os.path.join(savedir, 'weights_training_dset1_0.h5'))
# Restore per layer instead of slicing the flat model.get_weights() list, whose
# order may not follow layer order in every Keras version once layers are frozen.
weighted_layers = [idx for idx, layer in enumerate(model.layers) if layer.weights]
for layer_idx in weighted_layers[n_transferred:]:
    model.layers[layer_idx].set_weights(fresh_weights[layer_idx])

# Keep only the checkpoint with the lowest validation loss.
ckpt_training = ModelCheckpoint(os.path.join(savedir, 'weights_transfer_dset1_dset2_freeze_3.h5'),
                                monitor='val_loss',
                                verbose=0,
                                save_best_only=True,
                                save_weights_only=False,
                                mode='auto',
                                period=1
                               )

# `datagen` is the augmentation generator defined in the set-up cell.
history = model.fit_generator(datagen.flow(x_train_2, y_train_2, batch_size=batch_size),
                              epochs=epochs,
                              verbose=2,
                              validation_data=(x_test_2, y_test_2),
                              callbacks=[ckpt_training]
                              )
with open(os.path.join(savedir, 'history_transfer_dset1_dset2_freeze_3.pkl'), 'wb') as f:
    pickle.dump(history.history, f)

In [None]:
# Freeze-1 transfer histories, bound with a "_1" suffix for side-by-side comparison.
path_freeze1_12 = os.path.join(savedir, 'history_transfer_dset1_dset2_0.pkl')
with open(path_freeze1_12, 'rb') as fh:
    history_dset1_dset2_1 = pickle.load(fh)
path_freeze1_21 = os.path.join(savedir, 'history_transfer_dset2_dset1_0.pkl')
with open(path_freeze1_21, 'rb') as fh:
    history_dset2_dset1_1 = pickle.load(fh)

In [None]:
# Freeze-2 transfer histories, bound with a "_2" suffix.
path_freeze2_12 = os.path.join(savedir, 'history_transfer_dset1_dset2_freeze_2.pkl')
with open(path_freeze2_12, 'rb') as fh:
    history_dset1_dset2_2 = pickle.load(fh)
path_freeze2_21 = os.path.join(savedir, 'history_transfer_dset2_dset1_freeze_2.pkl')
with open(path_freeze2_21, 'rb') as fh:
    history_dset2_dset1_2 = pickle.load(fh)

In [None]:
# Freeze-3 transfer histories, bound with a "_3" suffix.
path_freeze3_12 = os.path.join(savedir, 'history_transfer_dset1_dset2_freeze_3.pkl')
with open(path_freeze3_12, 'rb') as fh:
    history_dset1_dset2_3 = pickle.load(fh)
path_freeze3_21 = os.path.join(savedir, 'history_transfer_dset2_dset1_freeze_3.pkl')
with open(path_freeze3_21, 'rb') as fh:
    history_dset2_dset1_3 = pickle.load(fh)