In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense, Dropout, MultiHeadAttention, Add, LayerNormalization
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
from tensorflow.keras.callbacks import EarlyStopping

# Fix the global random seed so repeated runs are comparable
# (per the Keras documentation, set_random_seed seeds Python's `random`, NumPy and TensorFlow)
seed = 123
tf.keras.utils.set_random_seed(seed)

In [3]:
%run utils.ipynb

In [None]:
def training_transformer_autoencoder(
    train_data,
    val_data,
    n_heads = 8,
    d_model = 128,
    num_encoder_layers = 1,
    num_decoder_layers = 1,
    feed_forward_dim = 256,
    dropout_rate = 0.2,
    learning_rate = 0.0001,
    n_epochs = 500,
    batch_size = 32,
    window_size = 20,
    metric = 'mse',
    plot = True,
    save = True
):
    """
    Build and train a Transformer-based denoising autoencoder

    Each time step is projected to `d_model` dimensions, a sinusoidal positional
    encoding is added, `num_encoder_layers` + `num_decoder_layers` self-attention
    blocks are applied, and a sigmoid Dense layer maps back to the input feature
    dimension. The model is fitted to reconstruct the clean training windows
    from a Gaussian-noised copy of them.

    Parameters:
        train_data: 3-D array of training windows; axis 1 must equal `window_size`
            and axis 2 is the feature axis (read from `train_data.shape[2]`)
        val_data: validation windows with the same layout, or None to train
            without validation and without early stopping
        n_heads: number of attention heads per block
        d_model: width of the internal representation
        num_encoder_layers / num_decoder_layers: number of stacked blocks
        feed_forward_dim: hidden width of the feed-forward sub-layer
        dropout_rate: dropout applied after attention and feed-forward sub-layers
        learning_rate: Adam learning rate
        n_epochs: maximum number of epochs
        batch_size: mini-batch size
        window_size: sequence length expected by the model input
        metric: 'mse' selects mean squared error; any other value selects
            mean absolute error
        plot: if True, save the loss curves to 'losses/<model name>.png'
        save: if True, pickle the trained model to 'models/<model name>.pkl'

    Returns:
        (history, autoencoder): the Keras History object and the trained model
    """

    # Get the number of features from the training data
    number_of_features = train_data.shape[2]
    # Optimizer and loss
    opt = Adam(learning_rate = learning_rate)
    loss_metric = MeanSquaredError() if metric == 'mse' else MeanAbsoluteError()

    # Positional Encoding
    def positional_encoding(position, d_mod):
        """
        Generate positional encoding for the input sequences
        This helps the model encode the order of the sequence elements
        Returns a float32 tensor of shape (1, position, d_mod)
        """
        angle_rads = np.arange(position)[:, np.newaxis] / np.power(
            10000, (2 * (np.arange(d_mod)[np.newaxis, :] // 2)) / np.float32(d_mod)
        )
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) # Apply sine to even indices
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) # Apply cosine to odd indices
        pos_enc = angle_rads[np.newaxis, ...] # Leading axis broadcasts over the batch
        return tf.cast(pos_enc, dtype = tf.float32)

    # Transformer Block
    def transformer_block(x_input, num_heads, feed_forward_d, d_mod, dropout):
        """
        Transformer block consisting of multi-head self-attention,
        feed-forward layers, layer normalization, and dropout
        """
        # Multi-Head Attention (query and value are both x_input, i.e. self-attention)
        # NOTE(review): key_dim is the per-head size, so every head is d_mod wide here
        # rather than d_mod // num_heads - confirm this is intended
        attention_output = MultiHeadAttention(num_heads = num_heads, key_dim = d_mod)(x_input, x_input)
        attention_output = Dropout(dropout)(attention_output)
        out1 = Add()([x_input, attention_output]) # Add residual connection
        out1 = LayerNormalization(epsilon = 1e-6)(out1) # Normalize the output

        # Feed-Forward Network
        ffn_output = Dense(feed_forward_d, activation = 'relu')(out1)
        ffn_output = Dense(d_mod)(ffn_output)
        ffn_output = Dropout(dropout)(ffn_output)
        out2 = Add()([out1, ffn_output]) # Add residual connection
        out2 = LayerNormalization(epsilon = 1e-6)(out2) # Normalize the output

        return out2

    # Add Noise
    def add_noise(data, noise_level):
        """
        Add Gaussian noise to the data for denoising autoencoder training
        """
        noisy_data = data + np.random.normal(scale = noise_level, size = data.shape)
        return noisy_data

    # Build the Autoencoder Model
    inputs = Input(shape = (window_size, number_of_features))

    # Apply Positional Encoding
    pos_encoding = positional_encoding(window_size, d_model)
    x = Dense(d_model)(inputs) # Map inputs to d_model dimensions
    x += pos_encoding # Add positional encoding

    # Encoder
    for _ in range(num_encoder_layers):
        x = transformer_block(x, n_heads, feed_forward_dim, d_model, dropout_rate)
    encoder_output = x # Final encoder output

    # Decoder (same self-attention block as the encoder, stacked on its output)
    x = encoder_output
    for _ in range(num_decoder_layers):
        x = transformer_block(x, n_heads, feed_forward_dim, d_model, dropout_rate)

    # Final Dense Layer
    # Sigmoid bounds reconstructions to (0, 1) - presumably the inputs are scaled to that range; confirm
    outputs = Dense(number_of_features, activation = 'sigmoid')(x) # Map back to original feature space

    # Compile the Model
    autoencoder = Model(inputs, outputs)
    autoencoder.compile(optimizer = opt, loss = loss_metric)

    # Add noise to training data
    noisy_train_data = add_noise(train_data, noise_level = 0.005)

    # Train Model
    fit_kwargs = dict(epochs = n_epochs, batch_size = batch_size, shuffle = False)
    if val_data is not None:
        # Early stopping monitors val_loss, so it is only meaningful with validation data
        early_stopping = EarlyStopping(monitor = 'val_loss', patience = 50, verbose = 1, mode = 'min', restore_best_weights = True)
        fit_kwargs['validation_data'] = (val_data, val_data)
        fit_kwargs['callbacks'] = [early_stopping]
    # The target is always the clean training data: the model learns to undo the noise.
    # (Using val_data as the target here would pair training inputs with unrelated windows.)
    history = autoencoder.fit(noisy_train_data, train_data, **fit_kwargs)

    # Generate model path using hyperparameters for easy identification
    lr = str(learning_rate).replace('.', '')
    dr = str(dropout_rate).replace('.', '')
    model_path = f'transformer_autoencoder_{num_encoder_layers}_{n_heads}_{d_model}_{feed_forward_dim}_{dr}_{lr}_{n_epochs}_{metric}_{batch_size}'

    # Plot training and validation loss if specified
    if plot:
        plt.figure(figsize = (10, 5))
        plt.plot(history.history['loss'], label = 'Training Loss')
        if 'val_loss' in history.history:
            plt.plot(history.history['val_loss'], label = 'Validation Loss')
        plt.title('Model Loss')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(loc='upper right')
        plt.tight_layout()
        os.makedirs('losses', exist_ok=True) # savefig does not create missing directories
        loss_save_path = 'losses/' + model_path
        plt.savefig(loss_save_path + '.png')
        plt.close()

    # Save model if specified
    if save:
        os.makedirs('models', exist_ok=True)
        model_save_path = 'models/' + model_path + '.pkl'
        # NOTE: the model is stored with pickle; only load such files from trusted sources
        with open(model_save_path, 'wb') as file:
            pickle.dump(autoencoder, file)

    return history, autoencoder

In [None]:
def compute_mas_importance(model, dataset):
    """
    Compute MAS importance for each trainable parameter

    For every batch yielded by `dataset`, the model is run in inference mode,
    the sum of squared outputs over the whole batch is differentiated with
    respect to each trainable variable, and the absolute gradient is
    accumulated. The accumulated values are then averaged over the number of
    batches (every batch carries equal weight, whatever its size).

    Parameters:
        model: callable model exposing `trainable_variables`
        dataset: iterable of input batches

    Returns:
        List of importance tensors, one per entry of `model.trainable_variables`
        and in the same order. Variables that receive no gradient keep zero importance.

    Raises:
        ValueError: if `dataset` yields no batches (the average would otherwise
            be a division by zero and produce NaN importances)
    """
    variables = model.trainable_variables
    importance = [tf.zeros_like(var, dtype = tf.float32) for var in variables]
    n_batches = 0

    for x_batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model(x_batch, training = False)
            loss = tf.reduce_sum(tf.square(outputs))
        grads = tape.gradient(loss, variables)
        # A gradient of None means the variable did not influence the output
        importance = [
            imp if g is None else imp + tf.abs(g)
            for imp, g in zip(importance, grads)
        ]
        n_batches += 1

    if n_batches == 0:
        raise ValueError("Cannot compute MAS importance: the dataset yielded no batches")

    # Average over all batches
    return [imp / float(n_batches) for imp in importance]

In [None]:
def get_current_weights(model):
    """
    Snapshot the model's trainable variables
    Each variable is passed through tf.identity and the resulting tensors are
    returned as a list, in the order of model.trainable_variables
    """
    snapshot = []
    for variable in model.trainable_variables:
        snapshot.append(tf.identity(variable))
    return snapshot

In [None]:
def mas_finetune_on_new_task(
    model_path,
    old_data,         # Data from old training to preserve performance
    new_data,         # Data for continual learning
    lambda_ = 1.0,    # MAS regularization weight
    n_epochs = 50,
    batch_size = 32,
    learning_rate = 1e-4
):
    """
    Load a previously trained Transformer autoencoder and
    apply MAS-based fine-tuning on a new task

    The loss minimised on each batch of `new_data` is the MSE reconstruction
    loss plus `lambda_` times the MAS penalty: the importance-weighted squared
    distance between the current parameters and the parameters the model had
    when it was loaded. Importances are computed on `old_data`.

    Parameters:
        model_path: path to the pickled model to fine-tune
        old_data: data of the previous task, used only to compute importances
        new_data: data of the new task; used as both input and reconstruction target
        lambda_: weight of the MAS penalty
        n_epochs: number of passes over `new_data`
        batch_size: batch size for both the importance pass and the fine-tuning
        learning_rate: Adam learning rate

    Returns:
        The fine-tuned model. It is also pickled to
        'models/transformer_autoencoder_MAS.pkl', and the loss curve is saved to
        'losses/transformer_autoencoder_MAS.png'.

    Raises:
        ValueError: if `old_data` or `new_data` contains no samples
    """
    save_path = "models/transformer_autoencoder_MAS.pkl"
    plot_path = "losses/transformer_autoencoder_MAS.png"

    # Build both datasets up front so that empty inputs are rejected before any work is done
    old_dataset = tf.data.Dataset.from_tensor_slices(old_data).batch(batch_size)
    new_dataset = tf.data.Dataset.from_tensor_slices((new_data, new_data)).batch(batch_size)
    for name, dataset in (('old_data', old_dataset), ('new_data', new_dataset)):
        if dataset.cardinality().numpy() == 0:
            raise ValueError(f"{name} is empty; MAS fine-tuning needs at least one sample")

    # Load trained model
    # NOTE: pickle.load can execute arbitrary code - only pass trusted model files
    with open(model_path, 'rb') as f:
        model = pickle.load(f)

    # Compute old importance using old data and store params
    old_importance = compute_mas_importance(model, old_dataset)
    old_params = get_current_weights(model)

    # Optimizer and loss
    optimizer = Adam(learning_rate = learning_rate)
    loss_metric = MeanSquaredError()
    loss_history = []

    # Train with MAS penalty
    for epoch in range(n_epochs):
        epoch_loss = 0.0
        num_batches = 0
        for x_batch, y_batch in new_dataset:
            with tf.GradientTape() as tape:
                predictions = model(x_batch, training = True)
                main_loss = loss_metric(y_batch, predictions)
                # MAS regularization: penalise drift of important parameters
                mas_reg = 0.0
                for param, old_param, imp in zip(model.trainable_variables, old_params, old_importance):
                    mas_reg += tf.reduce_sum(imp * tf.square(param - old_param))
                total_loss = main_loss + lambda_ * mas_reg
            # Backprop
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            epoch_loss += total_loss.numpy()
            num_batches += 1
        # num_batches is at least 1 here because new_dataset was checked to be non-empty
        epoch_loss /= num_batches
        loss_history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{n_epochs} - loss: {epoch_loss:.6f}")

    # Plot training loss
    plt.figure(figsize = (10, 5))
    plt.plot(range(1, n_epochs + 1), loss_history,  label = 'Continual Learning - Training Loss')
    plt.title('Updated model Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(loc = 'upper right') # Show the label set on the curve above
    plt.grid()
    os.makedirs(os.path.dirname(plot_path), exist_ok = True) # savefig does not create missing directories
    plt.savefig(plot_path)
    plt.close()

    # Save updated model
    os.makedirs(os.path.dirname(save_path), exist_ok = True)
    with open(save_path, 'wb') as f:
        pickle.dump(model, f)

    return model