# Transfer Learning
This notebook is part of the [SachsLab Workshop for Intracranial Neurophysiology and Deep Learning](https://github.com/SachsLab/IntracranialNeurophysDL).
Reference: the [TensorFlow transfer learning tutorial](https://www.tensorflow.org/alpha/tutorials/images/transfer_learning).
* Approach: freeze every layer except the first (input) layer and train only that layer; then unfreeze all layers and fine-tune the whole model.

Run the first two cells to normalize Local / Colab environments, then proceed below for the lesson.

In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that imported
# modules are re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import os
try:
    # See if we are running on google.colab
    import google.colab
    from google.colab import files
    os.chdir('..')
    if not (Path.home() / '.kaggle').is_dir():
        # Configure kaggle
        files.upload()  # Find the kaggle.json file in your ~/.kaggle directory.
        !pip install -q kaggle
        !mkdir -p ~/.kaggle
        !mv kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
    if Path.cwd().stem != 'IntracranialNeurophysDL':
        if not (Path.cwd() / 'IntracranialNeurophysDL').is_dir():
            # Download the workshop repo and change to its directory
            !git clone --recursive https://github.com/SachsLab/IntracranialNeurophysDL.git
        os.chdir('IntracranialNeurophysDL')
    IN_COLAB = True
    # Setup tensorflow 2.0
    !pip install -q tensorflow-gpu==2.0.0-alpha0
except ModuleNotFoundError:
    IN_COLAB = False
    import sys
    if Path.cwd().stem == 'notebooks':
        os.chdir(Path.cwd().parent)
    # Make sure the kaggle executable is on the PATH
    os.environ['PATH'] = os.environ['PATH'] + ';' + str(Path(sys.executable).parent / 'Scripts')

# Best-effort cleanup: remove the logs directory left behind by previous runs
# so that old TensorBoard runs do not clutter the new ones.
_old_logs_dir = Path.cwd() / 'logs'
if _old_logs_dir.is_dir():
    import shutil
    try:
        shutil.rmtree(str(_old_logs_dir))
    except PermissionError:
        # Not fatal: report the problem and carry on.
        print("Unable to remove logs directory.")

# Additional imports
import tensorflow as tf
import datetime
import numpy as np
import matplotlib.pyplot as plt
from indl import enable_plotly_in_cell
%load_ext tensorboard.notebook


In [None]:
# Reset the model weights to what they were before training.
# NOTE(review): neither `model` nor `initial_weights` is defined in the cells
# above; presumably both come from an earlier lesson/session. Confirm they are
# created before this cell, otherwise it fails on a fresh kernel.
model.set_weights(initial_weights)

def replace_input_layers(old_model, new_input_channels):
    """Build a model that reuses `old_model`'s layers behind a new input stage.

    A fresh `Input` with `new_input_channels` features per step feeds a new
    bias-free `Conv1D` with kernel size 1 and `N_SOURCES` filters (the spatial
    filter). Every layer of `old_model` from index 2 onward is then applied, in
    order, to that output. The reused layers are the very same layer objects as
    in `old_model`, not copies.

    NOTE(review): skipping indices 0 and 1 presumes that `old_model` starts
    with an input layer followed by its own spatial filter — confirm against
    the model being passed in.

    Args:
        old_model: Keras model whose layers from index 2 onward are reused.
        new_input_channels: Size of the last (channel) axis of the new input.

    Returns:
        A `tf.keras.Model` mapping the new input to the output of the last
        reused layer.
    """
    new_input = tf.keras.layers.Input(shape=(None, new_input_channels))
    spatial_filter = tf.keras.layers.Conv1D(N_SOURCES, 1, use_bias=False)
    features = spatial_filter(new_input)
    # Chain the reused layers, leaving out the old input stage (indices 0, 1).
    for reused_layer in old_model.layers[2:]:
        features = reused_layer(features)
    return tf.keras.Model(new_input, features)

# Switch to a different participant ('de'). Their recording has its own input
# size, so swap in a matching input stage and compile the result for training.
X, Y, ax_info = load_faces_houses(datadir, 'de', feature_set='full')
ds_train, ds_valid, n_train = get_ds_train_valid(
    X, Y,
    p_train=PTRAIN,
    batch_size=BATCH_SIZE,
    max_offset=100,
)
xfer_model = replace_input_layers(model, X.shape[-1])
xfer_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='Nadam',
    metrics=['accuracy'],
)
xfer_model.summary()

In [None]:
# Train the re-headed model on the new participant and keep the fit history.
history = xfer_model.fit(
    ds_train,
    epochs=N_EPOCHS,
    validation_data=ds_valid,
    verbose=1,
)

In [None]:
# Train one shared model across several participants in turn. For each
# participant: swap in an input stage matching their input size, train only
# that input stage, then unfreeze everything and fine-tune the whole model.
participant_names = ['ja', 'ca', 'wc', 'de', 'zt', 'fp']  # 'mv' is left out; it is used as unseen data below.
INPUT_EPOCHS = 100  # Can use a lot here because it's quite fast.
FINE_TUNE_EPOCHS = 50
BATCH_SIZE = 20
input_hists = []  # One fit history per participant for the input-only stage.
full_hists = []   # One fit history per participant for the fine-tuning stage.

for p_name in participant_names:

    # Load their data
    X, Y, ax_info = load_faces_houses(datadir, p_name, feature_set='full')
    ds_train, ds_valid, n_train = get_ds_train_valid(X, Y, p_train=PTRAIN, batch_size=BATCH_SIZE, max_offset=100)

    # Make a new model with the proper input size
    xfer_model = replace_input_layers(xfer_model, X.shape[-1])

    # Freeze layers other than input layers (indices 0 and 1).
    for layer in xfer_model.layers[2:]:
        layer.trainable = False

    # Train for INPUT_EPOCHS epochs to update the input layers only.
    # Compile after changing `trainable` so the new flags are picked up.
    xfer_model.compile(loss='sparse_categorical_crossentropy', optimizer='Nadam', metrics=['accuracy'])
    input_hists.append(
        xfer_model.fit(x=ds_train,
                       epochs=INPUT_EPOCHS,
                       validation_data=ds_valid,
                       verbose=1)
    )

    # Unfreeze all layers
    for layer in xfer_model.layers:
        layer.trainable = True

    # Fine-tuning: train all layers at a much lower learning rate.
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    xfer_model.compile(loss='sparse_categorical_crossentropy',
                       optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-5),
                       metrics=['accuracy'])
    full_hists.append(
        xfer_model.fit(x=ds_train,
                       epochs=FINE_TUNE_EPOCHS,
                       validation_data=ds_valid,
                       verbose=1)
    )

# Save the model. Pass a plain string: older Keras/h5py versions do not accept
# pathlib.Path objects as the file path.
xfer_model.save(str(datadir / 'converted' / 'faces_basic' / 'xfer_model_full.h5'))

## Transfer the model to unseen data

In [None]:
# On Colab, offer the freshly saved model file as a browser download.
if IN_COLAB:
    saved_model_path = datadir / 'converted' / 'faces_basic' / 'xfer_model_full.h5'
    files.download(saved_model_path)

In [None]:
# Load participant 'mv', who was not part of the training loop above, and
# split their data half-and-half into training and validation sets.
X, Y, ax_info = load_faces_houses(datadir, 'mv', feature_set='full')
ds_train, ds_valid, n_train = get_ds_train_valid(
    X, Y,
    p_train=0.5,
    batch_size=BATCH_SIZE,
    max_offset=100,
)

# Give the model an input stage matching this participant's input size.
xfer_model = replace_input_layers(xfer_model, X.shape[-1])
xfer_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='Nadam',
    metrics=['accuracy'],
)

In [None]:
log_dir = Path.cwd() / "logs" / datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_cb = tf.keras.callbacks.TensorBoard(str(log_dir), histogram_freq=1)
history = xfer_model.fit(x=ds_train,  
                         epochs=50,
                         validation_data=ds_valid,
                         callbacks=[tensorboard_cb],
                         verbose=1)
%tensorboard --logdir={str(log_dir)}
