# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd
from tensorflow.keras import layers, models, constraints, optimizers
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score

# ======================
# Custom Constraints
# ======================
class ClipConstraint(constraints.Constraint):
    """Weight constraint that clamps every element into ``[min_val, max_val]``.

    Args:
        min_val: Lower bound applied to each weight element.
        max_val: Upper bound applied to each weight element.
    """

    def __init__(self, min_val, max_val):
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, w):
        # Element-wise clamp of the weight tensor to the configured bounds.
        lower, upper = self.min_val, self.max_val
        return tf.clip_by_value(w, clip_value_min=lower, clip_value_max=upper)

    def get_config(self):
        # Constructor arguments keyed by parameter name.
        return dict(min_val=self.min_val, max_val=self.max_val)

# ======================
# Custom Activation Layers
# ======================


class SGBlend(layers.Layer):
    """Learnable blend of a sigmoid-approximated GELU and SSwish.

    Computes::

        a = sigmoid(alpha)
        y = a * (x * sigmoid(beta * x) - gamma) + (1 - a) * x * sigmoid(1.702 * x)

    ``alpha``, ``beta`` and ``gamma`` are scalar weights (shape ``(1,)``) that
    are broadcast over the whole input tensor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Blend logit. call() maps it through a sigmoid, so alpha=0 starts
        # at an equal 50/50 mix of the two terms. No constraint is attached:
        # the sigmoid already keeps the blend in (0, 1), and a norm
        # constraint on the logit would additionally cap |alpha| and thereby
        # shrink the reachable blend range.
        self.alpha = self.add_weight(
            name="alpha",
            shape=(1,),
            initializer="zeros",
        )

        # Swish slope. Initialised to 1 so the starting value lies inside
        # the clip range; a Glorot draw for a (1,)-shaped weight can be
        # negative, and constraints only take effect after optimizer updates.
        self.beta = self.add_weight(
            name="beta",
            shape=(1,),
            initializer="ones",
            constraint=ClipConstraint(0.1, 10.0),
        )

        # Vertical offset subtracted from the Swish term.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(1,),
            initializer="zeros",
        )
        super().build(input_shape)

    def call(self, x):
        # GELU approximation: x * sigmoid(1.702 * x)
        gelu_term = x * tf.nn.sigmoid(1.702 * x)
        # SSwish: x * sigmoid(beta * x) - gamma
        sswish_term = x * tf.nn.sigmoid(self.beta * x) - self.gamma
        # Squash the logit into (0, 1) to obtain the blend coefficient.
        blend = tf.nn.sigmoid(self.alpha)
        return blend * sswish_term + (1 - blend) * gelu_term

class SSwish(layers.Layer):
    """Symmetric Swish: ``x * sigmoid(beta * x) - gamma``.

    ``beta`` and ``gamma`` are scalar weights (shape ``(1,)``) broadcast over
    the whole input tensor. ``beta`` is clipped to ``[0.1, 10]`` by its
    constraint.
    """

    def build(self, input_shape):
        # Initialised to 1 so the starting slope lies inside the clip range;
        # a Glorot draw for a (1,)-shaped weight can be negative, and
        # constraints only take effect after optimizer updates.
        self.beta = self.add_weight(
            name='beta',
            shape=(1,),
            initializer='ones',
            constraint=ClipConstraint(0.1, 10),
        )
        # Vertical offset, starts at zero.
        self.gamma = self.add_weight(
            name='gamma',
            shape=(1,),
            initializer='zeros',
        )
        super().build(input_shape)

    def call(self, x):
        return x * tf.nn.sigmoid(self.beta * x) - self.gamma

class GELU(layers.Layer):
    """Layer wrapper around ``tf.nn.gelu`` (non-approximate form)."""

    def call(self, x):
        # approximate=False is tf.nn.gelu's default; spelled out for clarity.
        return tf.nn.gelu(x, approximate=False)

class Mish(layers.Layer):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def call(self, x):
        soft = tf.math.softplus(x)
        gate = tf.math.tanh(soft)
        return x * gate

# ======================
# Load and Preprocess Data
# ======================

# We'll use IMDB dataset for sentiment analysis
def load_data(num_words=10000, max_len=200):
    """Load IMDB reviews, re-tokenize them and pad to a fixed length.

    The integer-encoded reviews are decoded back to text, tokenized again
    with a Keras ``Tokenizer`` and padded/truncated to ``max_len`` tokens.

    Args:
        num_words: Vocabulary cap passed both to the IMDB loader and to the
            ``Tokenizer``.
        max_len: Length every sequence is padded (at the end) or truncated to.

    Returns:
        Tuple ``(x_train_pad, y_train, x_test_pad, y_test, word_index)`` where
        ``word_index`` is the tokenizer's word -> index mapping.
    """
    print("Loading IMDB dataset...")
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(
        num_words=num_words
    )

    # Map integer ids back to words.
    word_index = tf.keras.datasets.imdb.get_word_index()
    reverse_word_index = {index: word for word, index in word_index.items()}

    def decode_review(encoded):
        # Ids are shifted by 3 relative to get_word_index(); ids with no
        # matching word become '?'.
        return ' '.join(reverse_word_index.get(i - 3, '?') for i in encoded)

    x_train_decoded = [decode_review(review) for review in x_train]
    x_test_decoded = [decode_review(review) for review in x_test]

    # Fit the vocabulary on the training texts only, so nothing about the
    # test set leaks into preprocessing.
    tokenizer = Tokenizer(num_words=num_words)
    tokenizer.fit_on_texts(x_train_decoded)

    x_train_seq = tokenizer.texts_to_sequences(x_train_decoded)
    x_test_seq = tokenizer.texts_to_sequences(x_test_decoded)

    x_train_pad = pad_sequences(x_train_seq, maxlen=max_len, padding='post')
    x_test_pad = pad_sequences(x_test_seq, maxlen=max_len, padding='post')

    # Labels are returned untouched: the full dataset is used.
    return x_train_pad, y_train, x_test_pad, y_test, tokenizer.word_index

# ======================
# Custom Transformer Block
# ======================

def transformer_block(inputs, embed_dim, num_heads, ff_dim, activation, rate=0.1):
    """Build one encoder block: self-attention, then a two-layer feed-forward net.

    Each sub-layer is followed by dropout, a residual addition and layer
    normalization.

    Args:
        inputs: Tensor the block is applied to.
        embed_dim: Used as the attention ``key_dim`` and as the width of the
            feed-forward output projection.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer.
        activation: Either a ``layers.Layer`` instance, applied after the
            hidden Dense layer, or anything accepted by ``Dense(activation=...)``.
        rate: Dropout rate for both sub-layers.

    Returns:
        The block's output tensor.
    """
    # --- Self-attention sub-layer: the inputs attend to themselves ---
    self_attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
    attended = self_attention(inputs, inputs)
    attended = layers.Dropout(rate)(attended)
    out1 = layers.LayerNormalization(epsilon=1e-6)(inputs + attended)

    # --- Feed-forward sub-layer ---
    # Layer-style activations are applied as a separate step; anything else
    # is handed to the Dense layer itself.
    uses_layer_activation = isinstance(activation, layers.Layer)
    expand = layers.Dense(
        ff_dim, activation=None if uses_layer_activation else activation
    )
    hidden = expand(out1)
    if uses_layer_activation:
        hidden = activation(hidden)
    projected = layers.Dense(embed_dim)(hidden)
    projected = layers.Dropout(rate)(projected)

    return layers.LayerNormalization(epsilon=1e-6)(out1 + projected)

# ======================
# Define Model Architecture
# ======================

def create_model(activation, vocab_size, embed_dim=32, num_heads=2, ff_dim=32, max_len=200):
    """Create a small transformer classifier using the given activation.

    Architecture: embedding -> dropout -> two transformer blocks -> global
    average pooling -> dropout -> Dense(20) with ``activation`` -> sigmoid
    output unit.

    Args:
        activation: Either a ``layers.Layer`` instance or anything accepted
            by ``Dense(activation=...)``. When it is a Layer instance, the
            same instance (and therefore any trainable parameters it owns)
            is applied in both transformer blocks and in the hidden head.
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width, also passed to the transformer blocks.
        num_heads: Attention heads per transformer block.
        ff_dim: Hidden width of each block's feed-forward network.
        max_len: Length of the integer input sequences; must match the
            length the data was padded to.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = layers.Input(shape=(max_len,))
    x = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)(inputs)
    x = layers.Dropout(0.1)(x)

    # Two stacked transformer blocks.
    for _ in range(2):
        x = transformer_block(x, embed_dim, num_heads, ff_dim, activation)

    # Collapse the sequence axis.
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.1)(x)

    # Hidden head layer that uses the activation under test.
    if isinstance(activation, layers.Layer):
        x = layers.Dense(20)(x)
        x = activation(x)
    else:
        x = layers.Dense(20, activation=activation)(x)

    # Single sigmoid unit for binary classification.
    outputs = layers.Dense(1, activation="sigmoid")(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

# ======================
# Main Experiment
# ======================

def run_experiment():
    """Train one model per activation function and collect comparison metrics.

    For every activation the model is built, compiled with Adam (lr=1e-4) and
    binary cross-entropy, trained with early stopping on ``val_loss``, then
    evaluated on the test split.

    Returns:
        A list with one dict per activation, holding the keys ``name``,
        ``test_acc``, ``test_f1``, ``training_time`` (seconds),
        ``dead_neurons`` (fraction in [0, 1]) and ``history``.
    """
    # Load and preprocess data
    x_train, y_train, x_test, y_test, word_index = load_data()

    # Same for every model, so computed once. The +1 presumably leaves room
    # for index 0 used as padding.
    vocab_size = len(word_index) + 1

    # Activations to compare: Layer instances or plain callables.
    activations = {
        'SGBlend': SGBlend(),
        'GELU': GELU(),
        'Swish': tf.nn.swish,
        'SSwish': SSwish(),
        'ReLU': tf.nn.relu,
        'Mish': Mish(),
    }

    # Training parameters
    epochs = 15
    batch_size = 64
    results = []

    def measure_neuron_activity(model, x_data, layer_idx=-3):
        """Return the fraction of units whose peak activation is below 1e-3.

        The probed layer is the last one in ``model.layers[:layer_idx]``;
        activations are computed on the first 100 rows of ``x_data``.

        NOTE(review): with the architecture from create_model, the default
        layer_idx=-3 appears to select the pooled encoder output rather than
        the output of the activation under test - confirm this is the
        intended probe.
        """
        # Only the probed layer's output is needed, so the sub-model exposes
        # just that tensor instead of every intermediate output.
        probed_layer = model.layers[:layer_idx][-1]
        submodel = tf.keras.Model(inputs=model.input, outputs=probed_layer.output)
        layer_activations = submodel.predict(x_data[:100])

        # Peak value per unit: reduce over every axis except the last
        # (batch, plus the sequence axis for 3-D outputs).
        reduce_axes = tuple(range(layer_activations.ndim - 1))
        peak_per_unit = np.max(layer_activations, axis=reduce_axes)
        return np.mean(peak_per_unit < 1e-3)

    with tf.device('/GPU:0'):  # Use GPU if available
        for name, activation in activations.items():
            print(f"\nTraining with {name} activation...")

            model = create_model(activation, vocab_size)
            model.compile(
                optimizer=optimizers.Adam(learning_rate=1e-4),
                loss="binary_crossentropy",
                metrics=["accuracy"]
            )

            # Fresh callback per model so no stopping state carries over
            # from the previous run.
            early_stopping = EarlyStopping(
                monitor='val_loss',
                patience=3,
                restore_best_weights=True
            )

            # Train model (wall-clock time covers fit() only)
            start_time = time.time()
            history = model.fit(
                x_train, y_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_split=0.1,
                callbacks=[early_stopping],
                verbose=1
            )
            training_time = time.time() - start_time

            # Evaluate model: threshold sigmoid outputs at 0.5
            y_pred = (model.predict(x_test) > 0.5).astype(int).flatten()
            test_acc = accuracy_score(y_test, y_pred)
            test_f1 = f1_score(y_test, y_pred)

            # Measure dead neurons
            dead_neurons = measure_neuron_activity(model, x_train)

            # Store results
            results.append({
                'name': name,
                'test_acc': test_acc,
                'test_f1': test_f1,
                'training_time': training_time,
                'dead_neurons': dead_neurons,
                'history': history.history
            })

            print(f"{name} - Test Acc: {test_acc:.4f}, F1: {test_f1:.4f}, Time: {training_time:.1f}s, Dead Neurons: {dead_neurons:.2%}")

    return results

# ======================
# Visualize Results
# ======================

def visualize_results(results):
    """Plot a 2x2 summary of the experiment and save it as a PNG.

    Panels: validation accuracy per epoch, test accuracy/F1 bars, training
    time bars and dead-neuron percentage bars. The figure is written to
    ``nlp_activation_comparison.png`` and then shown.

    Args:
        results: List of result dicts with the keys ``name``, ``test_acc``,
            ``test_f1``, ``training_time``, ``dead_neurons`` (a fraction in
            [0, 1]) and ``history``.
    """
    names = [res['name'] for res in results]

    plt.figure(figsize=(20, 15))

    # 1. Validation Accuracy Plot
    plt.subplot(2, 2, 1)
    for res in results:
        val_accuracy = res['history'].get('val_accuracy')
        if val_accuracy is not None:
            plt.plot(val_accuracy, label=res['name'])
    plt.title('Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    # 2. Test Performance (Accuracy & F1), drawn as paired bars
    plt.subplot(2, 2, 2)
    acc = [res['test_acc'] for res in results]
    f1 = [res['test_f1'] for res in results]

    x = np.arange(len(names))
    width = 0.35

    plt.bar(x - width/2, acc, width, label='Accuracy')
    plt.bar(x + width/2, f1, width, label='F1 Score')

    plt.title('Test Performance')
    plt.xlabel('Activation Function')
    plt.ylabel('Score')
    plt.xticks(x, names, rotation=45)
    plt.legend()

    # 3. Training Time
    plt.subplot(2, 2, 3)
    times = [res['training_time'] for res in results]
    plt.bar(names, times)
    plt.title('Training Time')
    plt.xlabel('Activation Function')
    plt.xticks(rotation=45)
    plt.ylabel('Seconds')

    # 4. Dead Neurons Percentage
    plt.subplot(2, 2, 4)
    # Stored values are fractions; scale so the axis really shows percent.
    dead_pct = [100 * res['dead_neurons'] for res in results]
    plt.bar(names, dead_pct)
    plt.title('Dead Neurons Percentage')
    plt.xlabel('Activation Function')
    plt.xticks(rotation=45)
    plt.ylabel('Percentage')

    plt.tight_layout()
    plt.savefig('nlp_activation_comparison.png')
    plt.show()

# ======================
# Final Comparison
# ======================

def output_comparison_table(results):
    """Build a display-ready DataFrame of results, best test accuracy first.

    Args:
        results: List of result dicts with the keys ``name``, ``test_acc``,
            ``test_f1``, ``training_time`` and ``dead_neurons``.

    Returns:
        A ``pandas.DataFrame`` with one row per entry; every metric is
        already formatted as a string.
    """
    ranked = sorted(results, key=lambda entry: entry['test_acc'], reverse=True)

    rows = []
    for entry in ranked:
        rows.append({
            'Activation': entry['name'],
            'Test Accuracy': f"{entry['test_acc']:.4f}",
            'F1 Score': f"{entry['test_f1']:.4f}",
            'Training Time (s)': f"{entry['training_time']:.1f}",
            'Dead Neurons (%)': f"{entry['dead_neurons']:.2%}",
        })

    return pd.DataFrame(rows)

# ======================
# Main Execution
# ======================

if __name__ == "__main__":
    # Fix the TensorFlow and NumPy seeds for reproducibility.
    tf.random.set_seed(42)
    np.random.seed(42)

    # Train every variant, then plot the comparison.
    results = run_experiment()
    visualize_results(results)

    # Tabulate, print and persist the summary.
    comparison_table = output_comparison_table(results)
    print("\nActivation Function Comparison (sorted by test accuracy):")
    print(comparison_table)
    comparison_table.to_csv('nlp_activation_comparison.csv', index=False)