# In [None]:  (Jupyter cell marker left over from the notebook export)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd
from tensorflow.keras import layers, models, constraints, optimizers
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
import tensorflow_datasets as tfds

# ======================
# Custom Constraints (Same as Before)
# ======================
class ClipConstraint(constraints.Constraint):
    """Weight constraint that clamps every element into ``[min_val, max_val]``.

    Args:
        min_val: Lower bound handed to ``tf.clip_by_value``.
        max_val: Upper bound handed to ``tf.clip_by_value``.
    """

    def __init__(self, min_val, max_val):
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, w):
        """Return ``w`` with each element clipped to the configured bounds."""
        lower, upper = self.min_val, self.max_val
        return tf.clip_by_value(w, lower, upper)

    def get_config(self):
        """Return the constructor arguments as a plain dict."""
        return dict(min_val=self.min_val, max_val=self.max_val)

# ======================
# Custom Activation Layers (Same as Before)
# ======================
class SGBlend(layers.Layer):
    """Learnable blend of a sigmoid-approximated GELU and a shifted Swish.

    For input ``x`` the layer returns::

        a * (x * sigmoid(beta * x) - gamma) + (1 - a) * (x * sigmoid(1.702 * x))

    where ``a = sigmoid(alpha)`` and ``alpha``, ``beta``, ``gamma`` are
    trainable scalars (shape ``(1,)``) created in ``build``.

    Args:
        **kwargs: Forwarded unchanged to ``layers.Layer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the three scalar parameters."""
        # Raw blend logit. call() squashes it with a sigmoid, so the zero
        # initial value yields an even 50/50 mix of the two terms.
        # No constraint is attached: the sigmoid already bounds the mixing
        # weight, and a MinMaxNorm(0, 1) norm constraint on top of it would cap
        # |alpha| at 1 and pin the mix to roughly [0.27, 0.73].
        self.alpha = self.add_weight(
            name="alpha",
            shape=(1,),
            initializer="zeros",
        )

        # Slope of the Swish gate.
        # NOTE(review): glorot_uniform can draw a value outside [0.1, 10];
        # presumably the constraint only takes effect after the first
        # optimizer update - confirm.
        self.beta = self.add_weight(
            name="beta",
            shape=(1,),
            initializer="glorot_uniform",
            constraint=ClipConstraint(0.1, 10.0)
        )

        # Vertical shift subtracted from the Swish term; starts at zero.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(1,),
            initializer="zeros"
        )

        super().build(input_shape)

    def call(self, x):
        """Return the element-wise blend of the two activation terms."""
        # GELU approximation: x * sigmoid(1.702 x)
        gelu_term = x * tf.nn.sigmoid(1.702 * x)
        # SSwish: x * sigmoid(beta x) - gamma
        sswish_term = x * tf.nn.sigmoid(self.beta * x) - self.gamma
        # Mixing weight in (0, 1); larger values favour the SSwish term.
        alpha = tf.nn.sigmoid(self.alpha)
        return alpha * sswish_term + (1 - alpha) * gelu_term

class SSwish(layers.Layer):
    """Symmetric Swish: ``x * sigmoid(beta * x) - gamma`` with trainable scalars."""

    def build(self, input_shape):
        """Create the ``beta`` (clipped to [0.1, 10]) and ``gamma`` weights."""
        scalar_shape = (1,)
        self.beta = self.add_weight(
            name="beta",
            shape=scalar_shape,
            initializer="glorot_uniform",
            constraint=ClipConstraint(0.1, 10),
        )
        self.gamma = self.add_weight(
            name="gamma",
            shape=scalar_shape,
            initializer="zeros",
        )
        super().build(input_shape)

    def call(self, x):
        """Apply the gated activation element-wise and subtract the shift."""
        gate = tf.nn.sigmoid(self.beta * x)
        return x * gate - self.gamma

class GELU(layers.Layer):
    """Gaussian Error Linear Unit, exposed as a Keras layer via ``tf.nn.gelu``."""

    def call(self, x):
        """Return ``tf.nn.gelu`` applied to ``x``."""
        activated = tf.nn.gelu(x)
        return activated

class Mish(layers.Layer):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def call(self, x):
        """Apply Mish element-wise to ``x``."""
        soft = tf.math.softplus(x)
        gate = tf.math.tanh(soft)
        return x * gate

# ======================
# Positional Encoding (Same as Before)
# ======================
def positional_encoding(length, depth):
    """Build a sinusoidal position table of shape ``(length, depth)``.

    The first block of columns holds ``sin(pos * rate)`` and the second block
    holds ``cos(pos * rate)``, with ``rate = 1 / 10000 ** (i / (depth / 2))``.

    Args:
        length: Number of positions (rows).
        depth: Width of the encoding (columns).

    Returns:
        A ``tf.float32`` tensor with ``length`` rows and ``depth`` columns.
    """
    half_depth = depth / 2
    positions = np.arange(length)[:, np.newaxis]
    depths = np.arange(half_depth)[np.newaxis, :] / half_depth
    angle_rates = 1 / (10000**depths)
    angle_rads = positions * angle_rates
    pos_encoding = np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)
    # For an odd depth np.arange(depth / 2) produces ceil(depth / 2) columns, so
    # the concatenation is depth + 1 wide; trim it so the table always matches
    # the requested width. Even depths are unaffected by the slice.
    pos_encoding = pos_encoding[:, :int(depth)]
    return tf.cast(pos_encoding, dtype=tf.float32)

class PositionalEmbedding(layers.Layer):
    """Token embedding scaled by ``sqrt(d_model)`` plus a fixed position table.

    Args:
        vocab_size: Number of rows of the embedding matrix.
        d_model: Embedding width; also the width of the position table.
        max_length: Number of positions precomputed for the position table.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, vocab_size, d_model, max_length=512, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are kept so get_config() can report them.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_length = max_length
        # NOTE(review): mask_zero=True is set here, but this layer defines no
        # compute_mask, so presumably the padding mask is not propagated to
        # the layers that follow - confirm before relying on masking.
        self.embedding = layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(max_length, d_model)

    def call(self, x):
        """Embed ``x`` and add the position rows for its sequence length."""
        length = tf.shape(x)[1]
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # Only the first `length` precomputed positions are used; the new
        # leading axis lets the table broadcast over the batch.
        x += self.pos_encoding[tf.newaxis, :length, :]
        return x

    def get_config(self):
        """Return the base layer config extended with the constructor arguments."""
        config = super().get_config()
        config.update({
            'vocab_size': self.vocab_size,
            'd_model': self.d_model,
            'max_length': self.max_length,
        })
        return config

# ======================
# BERT Components
# ======================
def encoder_block(inputs, embed_dim, num_heads, ff_dim, activation, rate=0.1):
    """Apply one Transformer encoder block to ``inputs``.

    The block is self-attention followed by a two-layer feed-forward network.
    Each sub-layer is followed by dropout, a residual add and layer
    normalization.

    Args:
        inputs: Tensor fed to the block (used as both query and value).
        embed_dim: Attention ``key_dim`` and output width of the block.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        activation: Either a ``layers.Layer`` instance, inserted between the
            two Dense layers, or anything accepted by ``Dense(activation=...)``.
        rate: Dropout rate for both sub-layers.

    Returns:
        The layer-normalized output of the feed-forward sub-layer.
    """
    # Self-attention sub-layer: attend, drop, add residual, normalize.
    self_attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
    attended = self_attention(inputs, inputs)
    attended = layers.Dropout(rate)(attended)
    attn_normed = layers.LayerNormalization(epsilon=1e-6)(inputs + attended)

    # Feed-forward sub-layer: a layer-type activation sits between the Dense
    # layers, a callable/string activation is fused into the first Dense.
    if isinstance(activation, layers.Layer):
        ffn_stack = [layers.Dense(ff_dim), activation, layers.Dense(embed_dim)]
    else:
        ffn_stack = [layers.Dense(ff_dim, activation=activation), layers.Dense(embed_dim)]
    feed_forward = models.Sequential(ffn_stack)

    projected = feed_forward(attn_normed)
    projected = layers.Dropout(rate)(projected)
    return layers.LayerNormalization(epsilon=1e-6)(attn_normed + projected)

# ======================
# Load and Preprocess IMDB Dataset
# ======================
def load_imdb_data(max_samples=25000, max_length=256):
    """Load IMDB reviews and turn them into padded token-id arrays.

    Each review is prefixed with ``'[CLS] '``. The first 90% of the loaded
    reviews form the training set and the rest the validation set. The
    tokenizer is fitted on the training texts only.

    The test split is not part of the return value, so it is not processed.

    Args:
        max_samples: Maximum number of training-split reviews to load.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        ``((x_train, y_train), (x_val, y_val), tokenizer)`` where the ``x``
        arrays come from ``pad_sequences`` and the ``y`` arrays are NumPy
        arrays of labels.
    """
    print("Loading IMDB dataset...")

    dataset = tfds.load('imdb_reviews', as_supervised=True)
    train_data = dataset['train'].take(max_samples)

    # NOTE(review): the Tokenizer is built with its default filters and
    # lower-casing, so presumably the marker is stored as the token "cls"
    # rather than "[CLS]" - confirm if the exact token matters.
    texts, labels = [], []
    for text, label in tfds.as_numpy(train_data):
        texts.append('[CLS] ' + text.decode('utf-8'))
        labels.append(label)

    # Split into train/validation (first 90% / last 10%, in loaded order).
    split = int(len(texts) * 0.9)
    x_train, x_val = texts[:split], texts[split:]
    y_train, y_val = labels[:split], labels[split:]

    tokenizer = Tokenizer(oov_token='[UNK]')
    tokenizer.fit_on_texts(x_train)

    def encode(batch):
        # truncating='post' keeps the leading [CLS] marker at index 0 for
        # reviews longer than max_length. The pad_sequences default ('pre')
        # would cut it off, and the classifier reads position 0.
        return pad_sequences(
            tokenizer.texts_to_sequences(batch),
            maxlen=max_length,
            padding='post',
            truncating='post',
        )

    x_train_pad = encode(x_train)
    x_val_pad = encode(x_val)

    y_train = np.array(y_train)
    y_val = np.array(y_val)

    print(f"Vocabulary size: {len(tokenizer.word_index)+1}")
    print(f"Training samples: {len(x_train_pad)}")
    print(f"Validation samples: {len(x_val_pad)}")

    return (x_train_pad, y_train), (x_val_pad, y_val), tokenizer

# ======================
# BERT Model Architecture
# ======================
def create_bert_model(activation, vocab_size, max_length=256,
                     d_model=64, num_heads=2, ff_dim=128, num_layers=2):
    """Assemble a small encoder-only binary text classifier.

    Args:
        activation: Feed-forward activation; the same object is handed to
            every encoder block.
        vocab_size: Vocabulary size for the embedding layer.
        max_length: Fixed input sequence length.
        d_model: Embedding / hidden width.
        num_heads: Attention heads per encoder block.
        ff_dim: Hidden width of each feed-forward network.
        num_layers: Number of stacked encoder blocks.

    Returns:
        A Keras ``Model`` mapping ``(batch, max_length)`` inputs to a single
        sigmoid output per example.
    """
    token_ids = layers.Input(shape=(max_length,))
    hidden = PositionalEmbedding(vocab_size, d_model, max_length)(token_ids)

    # Run the embedded sequence through the encoder stack.
    for _ in range(num_layers):
        hidden = encoder_block(hidden, d_model, num_heads, ff_dim, activation)

    # The representation at position 0 (the [CLS] slot) feeds the classifier.
    first_position = hidden[:, 0, :]
    prediction = layers.Dense(1, activation='sigmoid')(first_position)

    return models.Model(inputs=token_ids, outputs=prediction)

# ======================
# Training and Evaluation
# ======================
def run_experiment():
    """Train one classifier per activation function and collect its metrics.

    Training runs on ``/GPU:0`` when TensorFlow reports a GPU. If that run
    raises, the error is printed and the whole comparison is repeated on
    ``/CPU:0`` with fresh layers, so all reported timings come from one
    device. Errors raised on the CPU propagate to the caller.

    Returns:
        list[dict]: One entry per activation with the keys ``name``,
        ``val_acc``, ``val_loss``, ``training_time`` (seconds),
        ``dead_neurons`` (fraction in [0, 1]) and ``history``.
    """
    (x_train, y_train), (x_val, y_val), tokenizer = load_imdb_data()
    vocab_size = len(tokenizer.word_index) + 1
    # The padded arrays fix the sequence length the model has to accept.
    max_length = x_train.shape[1]

    # Training parameters
    epochs = 25
    batch_size = 32

    def build_activations():
        # Fresh layer instances on every call, so a CPU retry never reuses
        # weights that were already trained during a failed GPU run.
        return {
            'SGBlend': SGBlend(),
            'GELU': GELU(),
            'Swish': tf.nn.swish,
            'SSwish': SSwish(),
            'ReLU': tf.nn.relu,
            'Mish': Mish(),
        }

    def measure_neuron_activity(model, x_data, layer_idx=-2):
        # Probe only the last layer in model.layers[:layer_idx]; that is the
        # only output the metric uses, so the other intermediate outputs are
        # neither referenced nor computed.
        probe_layer = model.layers[:layer_idx][-1]
        probe_model = tf.keras.Model(inputs=model.inputs, outputs=probe_layer.output)

        features = probe_model.predict(x_data[:100])

        # A unit counts as dead when its maximum over the sampled batch (and
        # over the sequence axis for 3-D outputs) stays below 1e-3.
        if len(features.shape) > 2:
            dead = np.mean(np.max(features, axis=(0, 1)) < 1e-3)
        else:
            dead = np.mean(np.max(features, axis=0) < 1e-3)

        return np.mean(dead)

    def train_all(device):
        # Train and evaluate every activation on the given device.
        results = []
        with tf.device(device):
            for name, activation in build_activations().items():
                print(f"\nTraining with {name} activation...")

                model = create_bert_model(
                    activation,
                    vocab_size,
                    max_length=max_length
                )

                model.compile(
                    optimizer=optimizers.Adam(1e-4),
                    loss='binary_crossentropy',
                    metrics=['accuracy']
                )

                # One callback per model so no state is shared between runs.
                early_stopping = EarlyStopping(
                    monitor='val_loss',
                    patience=2,
                    restore_best_weights=True
                )

                start = time.perf_counter()
                history = model.fit(
                    x_train, y_train,
                    validation_data=(x_val, y_val),
                    epochs=epochs,
                    batch_size=batch_size,
                    callbacks=[early_stopping],
                    verbose=1
                )
                train_time = time.perf_counter() - start

                val_loss, val_acc = model.evaluate(x_val, y_val, verbose=0)
                dead_neurons = measure_neuron_activity(model, x_train)

                results.append({
                    'name': name,
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                    'training_time': train_time,
                    'dead_neurons': dead_neurons,
                    'history': history.history
                })

                print(f"{name} | Val Acc: {val_acc:.4f} | Time: {train_time:.1f}s")
        return results

    # Without a GPU there is nothing to fall back from: train on the CPU and
    # let any error surface as-is.
    if not tf.config.list_physical_devices('GPU'):
        return train_all('/CPU:0')

    try:
        return train_all('/GPU:0')
    except Exception as e:
        print(f"GPU Error: {e}, falling back to CPU...")
        return train_all('/CPU:0')

# ======================
# Visualization and Reporting (Same Structure)
# ======================
def visualize_results(results):
    """Plot a 2x2 comparison figure, save it and show it.

    Panels: validation accuracy per epoch, validation loss per epoch, total
    training time, and share of dead neurons. The figure is written to
    ``bert_activation_comparison.png``.

    Args:
        results: List of dicts as produced by ``run_experiment``; each needs
            ``name``, ``training_time``, ``dead_neurons`` and a ``history``
            dict with ``val_accuracy`` and ``val_loss``.
    """
    # Bar labels come from the results themselves, so they always match the
    # number and order of the plotted values.
    names = [res['name'] for res in results]

    plt.figure(figsize=(20, 15))

    # Per-epoch curves: one line per activation.
    curve_panels = (
        (1, 'val_accuracy', 'Validation Accuracy'),
        (2, 'val_loss', 'Validation Loss'),
    )
    for position, history_key, title in curve_panels:
        plt.subplot(2, 2, position)
        for res in results:
            plt.plot(res['history'][history_key], label=res['name'])
        plt.title(title)
        plt.legend()

    # Training Time
    plt.subplot(2, 2, 3)
    plt.bar(names, [res['training_time'] for res in results])
    plt.title('Training Time')

    # Dead Neurons: stored as a fraction, scaled by 100 to match the
    # "Percentage" title.
    plt.subplot(2, 2, 4)
    plt.bar(names, [100 * res['dead_neurons'] for res in results])
    plt.title('Dead Neurons Percentage')

    plt.tight_layout()
    plt.savefig('bert_activation_comparison.png')
    plt.show()

def output_comparison_table(results):
    """Return a formatted summary table, best validation accuracy first.

    Rows are ordered by the numeric ``val_acc`` value (descending) rather
    than by its formatted string. The columns are declared explicitly, so an
    empty ``results`` list yields an empty table instead of raising. Each row
    keeps the position of its entry in ``results`` as its index label.

    Args:
        results: List of dicts with ``name``, ``val_acc``, ``val_loss``,
            ``training_time`` and ``dead_neurons`` (a fraction).

    Returns:
        pandas.DataFrame with string-formatted columns ``Activation``,
        ``Val Accuracy``, ``Val Loss``, ``Training Time (s)`` and
        ``Dead Neurons (%)``.
    """
    columns = [
        'Activation',
        'Val Accuracy',
        'Val Loss',
        'Training Time (s)',
        'Dead Neurons (%)',
    ]

    # Rank on the raw float; sorted() is stable, so ties keep input order.
    order = sorted(
        range(len(results)),
        key=lambda i: results[i]['val_acc'],
        reverse=True,
    )

    rows = [{
        'Activation': results[i]['name'],
        'Val Accuracy': f"{results[i]['val_acc']:.4f}",
        'Val Loss': f"{results[i]['val_loss']:.4f}",
        'Training Time (s)': f"{results[i]['training_time']:.1f}",
        'Dead Neurons (%)': f"{results[i]['dead_neurons']:.2%}"
    } for i in order]

    return pd.DataFrame(rows, index=order, columns=columns)

# ======================
# Main Execution
# ======================
if __name__ == "__main__":
    # Seed both RNGs before any model or data work happens.
    tf.random.set_seed(42)
    np.random.seed(42)

    # Train every activation variant, then plot and tabulate the outcome.
    results = run_experiment()
    visualize_results(results)
    df = output_comparison_table(results)

    # Print the table under its banner and save it without the index column.
    print("\n=== Final Results ===", df, sep="\n")
    df.to_csv('bert_activation_results.csv', index=False)