# In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd
from tensorflow.keras import layers, models, constraints, optimizers
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score
import tensorflow_datasets as tfds

# ======================
# Custom Constraints
# ======================
class ClipConstraint(constraints.Constraint):
    """Keras weight constraint that clamps every element into a fixed range.

    Args:
        min_val: lower bound applied element-wise.
        max_val: upper bound applied element-wise.
    """

    def __init__(self, min_val, max_val):
        self.min_val, self.max_val = min_val, max_val

    def __call__(self, w):
        lower, upper = self.min_val, self.max_val
        return tf.clip_by_value(w, lower, upper)

    def get_config(self):
        # Serialized form; the keys mirror the constructor arguments.
        return dict(min_val=self.min_val, max_val=self.max_val)

# ======================
# Custom Activation Layers
# ======================
class AdaptiveSSwishGELU(layers.Layer):
    """Learnable blend of a GELU approximation and Symmetric Swish (SSwish).

    Computes ``alpha * sswish(x) + (1 - alpha) * gelu(x)`` where
    ``sswish(x) = x * sigmoid(beta * x) - gamma`` and
    ``gelu(x) ~= x * sigmoid(1.702 * x)``. ``alpha``, ``beta`` and ``gamma``
    are single scalars broadcast over every element of the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # Blend weight, kept in [0, 1] by clipping after every optimizer step
        # and used directly as the mixing coefficient. It starts at 0, so the
        # layer is exactly the GELU approximation at initialization.
        self.alpha = self.add_weight(
            name="alpha",
            shape=(1,),
            initializer="zeros",
            constraint=ClipConstraint(0.0, 1.0),
        )

        # Swish slope. Keras applies weight constraints only after gradient
        # updates, never at initialization, so start from a value that is
        # already inside [0.1, 10] (beta = 1 is standard Swish).
        self.beta = self.add_weight(
            name="beta",
            shape=(1,),
            initializer="ones",
            constraint=ClipConstraint(0.1, 10.0),
        )

        # Learnable output offset of the SSwish branch.
        self.gamma = self.add_weight(
            name="gamma",
            shape=(1,),
            initializer="zeros",
        )

        super().build(input_shape)

    def call(self, x):
        # GELU approximation: x * sigmoid(1.702 x)
        gelu_term = x * tf.nn.sigmoid(1.702 * x)
        # SSwish: x * sigmoid(beta x) - gamma
        sswish_term = x * tf.nn.sigmoid(self.beta * x) - self.gamma
        # alpha is already clipped to [0, 1]; no extra squashing, so alpha = 0
        # really means pure GELU and alpha = 1 pure SSwish.
        return self.alpha * sswish_term + (1.0 - self.alpha) * gelu_term

class SSwish(layers.Layer):
    """Symmetric Swish: ``x * sigmoid(beta * x) - gamma``.

    ``beta`` and ``gamma`` are learnable scalars broadcast over the input;
    ``beta`` is clipped to [0.1, 10] after each optimizer step.
    """

    def build(self, input_shape):
        # Keras applies constraints only after gradient updates, not at
        # initialization. Start beta inside the allowed [0.1, 10] range
        # (beta = 1 is standard Swish) instead of a glorot draw that can be
        # negative or below 0.1 for the first forward passes.
        self.beta = self.add_weight(
            name='beta',
            shape=(1,),
            initializer='ones',
            constraint=ClipConstraint(0.1, 10)
        )
        # Learnable output offset.
        self.gamma = self.add_weight(
            name='gamma',
            shape=(1,),
            initializer='zeros'
        )
        super().build(input_shape)

    def call(self, x):
        return x * tf.nn.sigmoid(self.beta * x) - self.gamma

class GELU(layers.Layer):
    """Parameter-free layer wrapping the exact (erf-based) Gaussian Error Linear Unit."""

    def call(self, x):
        activated = tf.nn.gelu(x, approximate=False)
        return activated

class Mish(layers.Layer):
    """Parameter-free Mish activation layer: ``x * tanh(softplus(x))``."""

    def call(self, x):
        smooth = tf.math.softplus(x)
        gate = tf.math.tanh(smooth)
        return x * gate


# ======================
# Load and Preprocess WMT 2014 English-German Dataset
# ======================
def load_wmt_data(max_samples=50000, max_length=50):
    """Load, tokenize and pad a subset of the WMT 2014 English-German dataset.

    German targets are wrapped in ``<start>`` / ``<end>`` markers before
    tokenization. Both tokenizers use ``filters=''`` (so the markers survive)
    and are fitted on the whole subset, validation part included.

    Args:
        max_samples: number of sentence pairs taken from the start of the
            ``train`` split.
        max_length: length every sequence is padded / truncated to.

    Returns:
        ``((x_train, y_train), (x_val, y_val), en_tokenizer, de_tokenizer)``.
        ``x_*`` hold English token ids, ``y_*`` German token ids, each of
        shape ``(n, max_length)``. The split is the first 90% / last 10% of the
        subset, without shuffling.
    """
    print("Loading WMT 2014 English-German dataset...")

    # Select each language by its feature name instead of unpacking the
    # `as_supervised` tuple: TFDS lists this config's pair as ('de', 'en'),
    # i.e. German first, so an (en, de) unpacking silently swaps languages.
    dataset = tfds.load('wmt14_translate/de-en')
    train_dataset = dataset['train']

    # Take a subset for faster training
    train_dataset = train_dataset.take(max_samples)

    en_texts = []
    de_texts = []
    for example in tfds.as_numpy(train_dataset):
        en_texts.append(example['en'].decode('utf-8'))
        de_texts.append(example['de'].decode('utf-8'))

    # Add start and end tokens to target sequences
    de_texts_processed = ['<start> ' + text + ' <end>' for text in de_texts]

    # Create tokenizers
    en_tokenizer = Tokenizer(filters='')
    en_tokenizer.fit_on_texts(en_texts)

    de_tokenizer = Tokenizer(filters='')
    de_tokenizer.fit_on_texts(de_texts_processed)

    # Convert to sequences
    en_seqs = en_tokenizer.texts_to_sequences(en_texts)
    de_seqs = de_tokenizer.texts_to_sequences(de_texts_processed)

    # Pad at the end and also truncate overlong sequences at the end. The
    # pad_sequences default (truncating='pre') would cut the *beginning* of
    # long sentences, removing the '<start>' marker from long targets.
    en_pad = pad_sequences(en_seqs, maxlen=max_length, padding='post', truncating='post')
    de_pad = pad_sequences(de_seqs, maxlen=max_length, padding='post', truncating='post')

    # Split into training and validation sets (90/10)
    split = int(len(en_pad) * 0.9)

    x_train, x_val = en_pad[:split], en_pad[split:]
    y_train, y_val = de_pad[:split], de_pad[split:]

    # Vocabulary sizes (+1 for the reserved padding id 0)
    en_vocab_size = len(en_tokenizer.word_index) + 1
    de_vocab_size = len(de_tokenizer.word_index) + 1

    print(f"English vocabulary size: {en_vocab_size}")
    print(f"German vocabulary size: {de_vocab_size}")
    print(f"Training samples: {len(x_train)}")
    print(f"Validation samples: {len(x_val)}")

    return (x_train, y_train), (x_val, y_val), en_tokenizer, de_tokenizer

# ======================
# Transformer Components
# ======================
def positional_encoding(length, depth):
    """Create a sinusoidal positional-encoding table for the transformer.

    The leading columns hold sines and the trailing columns cosines
    (concatenated, not interleaved), at geometrically spaced frequencies.

    Args:
        length: number of positions (rows).
        depth: encoding width (columns), normally the model's ``d_model``.

    Returns:
        A float32 tensor of shape ``(length, depth)``.
    """
    half = depth / 2

    positions = np.arange(length)[:, np.newaxis]           # (length, 1)
    freq_index = np.arange(half)[np.newaxis, :] / half     # (1, ceil(depth / 2))

    angle_rates = 1 / (10000 ** freq_index)
    angle_rads = positions * angle_rates

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    # For odd depth, sin + cos yields depth + 1 columns; trim so the table
    # always matches the width of the embedding it is added to.
    pos_encoding = pos_encoding[:, :int(depth)]

    return tf.cast(pos_encoding, dtype=tf.float32)

class PositionalEmbedding(layers.Layer):
    """Token embedding scaled by sqrt(d_model) plus a fixed sinusoidal table.

    Token id 0 is padding: the inner ``Embedding`` uses ``mask_zero=True`` and
    this layer exposes that padding mask through ``compute_mask``.

    Args:
        vocab_size: number of token ids (including the padding id 0).
        d_model: embedding width.
        max_length: number of rows in the positional table; longer inputs
            cannot be encoded.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, vocab_size, d_model, max_length=50, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.embedding = layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = positional_encoding(max_length, d_model)

    def compute_mask(self, *args, **kwargs):
        # The arithmetic in `call` yields new tensors, and this wrapper would
        # otherwise report no mask, so downstream layers would never see the
        # embedding's padding mask. Delegate to the embedding to forward it.
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # Scale embeddings before adding positions (as in "Attention Is All You Need").
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[tf.newaxis, :length, :]
        return x

def encoder_block(inputs, embed_dim, num_heads, ff_dim, activation, rate=0.1):
    """Build one post-norm Transformer encoder layer on top of ``inputs``.

    A self-attention sub-layer and a two-layer feed-forward network, each
    followed by dropout, a residual add and LayerNormalization.

    Args:
        inputs: tensor to encode; its last dim must equal ``embed_dim`` for
            the residual adds to be valid.
        embed_dim: block output width; also the per-head ``key_dim`` of the
            attention layer.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        activation: either a Keras ``Layer`` instance (applied after a linear
            Dense) or anything accepted by ``Dense(activation=...)``.
        rate: dropout rate for both sub-layers.

    Returns:
        The block's normalized output tensor.
    """
    # --- self-attention sub-layer ---
    attended = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)(inputs, inputs)
    attended = layers.Dropout(rate)(attended)
    normed = layers.LayerNormalization(epsilon=1e-6)(inputs + attended)

    # --- feed-forward sub-layer ---
    # Layer-style activations are applied as their own step after a linear
    # Dense; plain activation functions are fused into the Dense itself.
    is_layer_activation = isinstance(activation, layers.Layer)
    hidden = layers.Dense(ff_dim, activation=None if is_layer_activation else activation)(normed)
    if is_layer_activation:
        hidden = activation(hidden)
    projected = layers.Dense(embed_dim)(hidden)

    projected = layers.Dropout(rate)(projected)
    return layers.LayerNormalization(epsilon=1e-6)(normed + projected)

def decoder_block(inputs, context, embed_dim, num_heads, ff_dim, activation, rate=0.1):
    """Transformer decoder block with configurable activation function.

    Three post-norm sub-layers, each followed by dropout, a residual add and
    LayerNormalization:
    1. causal self-attention over ``inputs``;
    2. cross-attention from that result to ``context``;
    3. a two-layer feed-forward network.

    Args:
        inputs: decoder hidden states; last dim must equal ``embed_dim`` for
            the residual adds.
        context: encoder output attended to by the cross-attention.
        embed_dim: block width; also the per-head ``key_dim`` of both
            attention layers.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        activation: a Keras ``Layer`` instance (applied after a linear Dense)
            or anything accepted by ``Dense(activation=...)``.
        rate: dropout rate for every sub-layer.

    Returns:
        The block's normalized output tensor.
    """
    # Causal self-attention: each position may attend only to itself and
    # earlier positions. Without this mask, a decoder trained with teacher
    # forcing (decoder inputs = targets shifted right by one) can read the
    # token it must predict from the next input position.
    # NOTE: `use_causal_mask` requires TensorFlow >= 2.10.
    self_attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim
    )(inputs, inputs, use_causal_mask=True)
    self_attention = layers.Dropout(rate)(self_attention)
    out1 = layers.LayerNormalization(epsilon=1e-6)(inputs + self_attention)

    # Cross-attention: queries from the decoder, keys/values from the encoder.
    cross_attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim
    )(out1, context)
    cross_attention = layers.Dropout(rate)(cross_attention)
    out2 = layers.LayerNormalization(epsilon=1e-6)(out1 + cross_attention)

    # Feed Forward Network
    if isinstance(activation, layers.Layer):
        # Custom activation layers are applied as a separate step
        ffn_output = layers.Dense(ff_dim)(out2)
        ffn_output = activation(ffn_output)
        ffn_output = layers.Dense(embed_dim)(ffn_output)
    else:
        # Built-in activation functions are fused into the Dense layer
        ffn_output = layers.Dense(ff_dim, activation=activation)(out2)
        ffn_output = layers.Dense(embed_dim)(ffn_output)

    ffn_output = layers.Dropout(rate)(ffn_output)
    return layers.LayerNormalization(epsilon=1e-6)(out2 + ffn_output)

# ======================
# Define Model Architecture
# ======================
def create_translation_model(activation, en_vocab_size, de_vocab_size, d_model=64, num_heads=2, ff_dim=128, num_layers=2):
    """Build a small encoder-decoder transformer for translation.

    Args:
        activation: activation used in every feed-forward sub-layer and in the
            output head; a Keras ``Layer`` instance (applied after a linear
            Dense) or anything accepted by ``Dense(activation=...)``. A Layer
            instance is reused at every call site, so any weights it owns are
            shared across the whole model.
        en_vocab_size: source vocabulary size (embedding rows).
        de_vocab_size: target vocabulary size (embedding rows and softmax width).
        d_model: embedding / block width.
        num_heads: attention heads per attention layer.
        ff_dim: hidden width of the feed-forward sub-layers and the output head.
        num_layers: number of encoder blocks and of decoder blocks.

    Returns:
        A functional ``Model`` mapping ``[encoder_inputs, decoder_inputs]`` to a
        per-position softmax over the target vocabulary.
    """
    uses_layer_activation = isinstance(activation, layers.Layer)

    # ----- encoder stack -----
    encoder_inputs = layers.Input(shape=(None,), name="encoder_inputs")
    context = PositionalEmbedding(en_vocab_size, d_model)(encoder_inputs)
    for _ in range(num_layers):
        context = encoder_block(context, d_model, num_heads, ff_dim, activation)

    # ----- decoder stack -----
    decoder_inputs = layers.Input(shape=(None,), name="decoder_inputs")
    hidden = PositionalEmbedding(de_vocab_size, d_model)(decoder_inputs)
    for _ in range(num_layers):
        hidden = decoder_block(hidden, context, d_model, num_heads, ff_dim, activation)

    # ----- output head: one activated projection, then softmax over the vocab -----
    hidden = layers.Dense(ff_dim, activation=None if uses_layer_activation else activation)(hidden)
    if uses_layer_activation:
        hidden = activation(hidden)
    probabilities = layers.Dense(de_vocab_size, activation="softmax")(hidden)

    return models.Model(inputs=[encoder_inputs, decoder_inputs], outputs=probabilities)

# ======================
# Training and Evaluation Functions
# ======================
def prepare_data_for_training(x, y):
    """Split padded target sequences into teacher-forcing inputs and labels.

    Args:
        x: encoder token array; passed through unchanged.
        y: 2-D array of target token ids (one row per example).

    Returns:
        ``([x, y[:, :-1]], y[:, 1:])``: the model inputs (encoder tokens and
        the targets without their last column) and the labels (the targets
        without their first column, i.e. one step ahead of the decoder input).
    """
    teacher_inputs, next_tokens = y[:, :-1], y[:, 1:]
    return [x, teacher_inputs], next_tokens

def bleu_score(y_true, y_pred, tokenizer):
    """Calculate corpus BLEU for translation quality evaluation.

    Args:
        y_true: iterable of reference token-id sequences (0 = padding).
        y_pred: iterable of per-position score arrays; the argmax over the
            last axis is taken as the predicted token id.
        tokenizer: fitted Keras ``Tokenizer`` whose ``index_word`` maps ids
            back to words.

    Returns:
        Corpus-level BLEU as a float (0.0 when there is nothing to score).

    Both sides are cut at the first ``<end>`` token, and the ``<start>`` /
    ``<end>`` markers and padding ids are dropped. Otherwise the end marker
    and whatever the model emits at padded positions would be scored as words.
    Smoothing (method1) keeps the score from collapsing to 0 when some
    higher-order n-gram has no match in a small sample.
    """
    from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

    def ids_to_words(ids):
        """Map token ids to words, stopping at '<end>' and skipping markers/padding."""
        words = []
        for idx in ids:
            if idx <= 0:  # padding
                continue
            word = tokenizer.index_word.get(int(idx), '')
            if word == '<end>':
                break
            if word and word != '<start>':
                words.append(word)
        return words

    hypotheses = [ids_to_words(np.argmax(pred, axis=-1)) for pred in y_pred]
    # corpus_bleu expects a *list* of references per hypothesis.
    references = [[ids_to_words(true)] for true in y_true]

    if not hypotheses:
        return 0.0
    return corpus_bleu(
        references,
        hypotheses,
        smoothing_function=SmoothingFunction().method1,
    )

# ======================
# Main Experiment
# ======================
def _measure_neuron_activity(model, x_data, sample_size=100, threshold=1e-3):
    """Return the fraction of units feeding the model's final layer that stay inactive.

    A unit counts as inactive ("dead") when its value never exceeds
    ``threshold`` on the first ``sample_size`` examples. The probed tensor is
    the input of ``model.layers[-1]``. For models built by
    ``create_translation_model``, that is the output of the last activation,
    just before the softmax projection.

    Args:
        model: functional Keras model taking ``[encoder_tokens, decoder_tokens]``.
        x_data: ``[encoder_tokens, decoder_tokens]`` arrays to sample from.
        sample_size: number of leading examples to probe.
        threshold: level below which a unit is considered inactive.

    Returns:
        Fraction in ``[0, 1]`` of units that are inactive on the sample.
    """
    probe = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-1].input)
    features = probe.predict(
        [x_data[0][:sample_size], x_data[1][:sample_size]], verbose=0
    )
    # Max over every axis except the last (feature) axis, i.e. batch and time.
    per_unit_max = np.max(features, axis=tuple(range(features.ndim - 1)))
    return float(np.mean(per_unit_max < threshold))


def run_experiment():
    """Train one translation model per activation function and collect metrics.

    Returns:
        A list with one dict per activation (in definition order) with keys
        ``name``, ``val_acc``, ``val_loss``, ``bleu``, ``training_time``,
        ``dead_neurons`` and ``history``.
    """
    # Load and preprocess data
    (x_train, y_train), (x_val, y_val), en_tokenizer, de_tokenizer = load_wmt_data()

    train_data, train_targets = prepare_data_for_training(x_train, y_train)
    val_data, val_targets = prepare_data_for_training(x_val, y_val)

    en_vocab_size = len(en_tokenizer.word_index) + 1
    de_vocab_size = len(de_tokenizer.word_index) + 1

    # Factories rather than instances: each model gets freshly initialized
    # activation weights instead of reusing a layer already built (and
    # possibly trained) inside an earlier model.
    activation_factories = {
        'AdaptiveSSwishGELU': AdaptiveSSwishGELU,
        'GELU': GELU,
        'Swish': lambda: tf.nn.swish,
        'SSwish': SSwish,
        'ReLU': lambda: tf.nn.relu,
        'Mish': Mish,
    }

    # Training parameters
    epochs = 10
    batch_size = 64

    def train_and_evaluate(name, activation):
        """Build, train and evaluate one model; return its result record."""
        model = create_translation_model(
            activation,
            en_vocab_size,
            de_vocab_size,
            d_model=64,
            num_heads=2,
            ff_dim=128,
            num_layers=2
        )
        model.compile(
            optimizer=optimizers.Adam(learning_rate=1e-4),
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"]
        )
        # Fresh callback per model so no early-stopping state is shared.
        early_stopping = EarlyStopping(
            monitor='val_loss',
            patience=2,
            restore_best_weights=True
        )

        start_time = time.time()
        history = model.fit(
            train_data, train_targets,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(val_data, val_targets),
            callbacks=[early_stopping],
            verbose=1
        )
        training_time = time.time() - start_time

        val_loss, val_acc = model.evaluate(val_data, val_targets, verbose=0)

        # BLEU on a small validation sample (decoder fed the ground-truth prefix)
        sample_size = min(100, len(val_data[0]))
        predictions = model.predict([val_data[0][:sample_size], val_data[1][:sample_size]])
        bleu = bleu_score(val_targets[:sample_size], predictions, de_tokenizer)

        dead_neurons = _measure_neuron_activity(model, train_data)

        print(f"{name} - Val Acc: {val_acc:.4f}, Val Loss: {val_loss:.4f}, BLEU: {bleu:.4f}, Time: {training_time:.1f}s, Dead Neurons: {dead_neurons:.2%}")

        return {
            'name': name,
            'val_acc': val_acc,
            'val_loss': val_loss,
            'bleu': bleu,
            'training_time': training_time,
            'dead_neurons': dead_neurons,
            'history': history.history
        }

    # Choose the device once, up front. Catching arbitrary exceptions and
    # re-running everything would hide real errors and append duplicate
    # records for models that had already finished.
    device = '/GPU:0' if tf.config.list_physical_devices('GPU') else '/CPU:0'

    results = []
    with tf.device(device):
        for name, make_activation in activation_factories.items():
            print(f"\nTraining with {name} activation...")
            results.append(train_and_evaluate(name, make_activation()))

    return results

# ======================
# Visualize Results
# ======================
def visualize_results(results):
    """Plot a 2x2 summary of the activation comparison and save it.

    Panels:
    - per-epoch validation accuracy;
    - final validation accuracy and BLEU, both on a 0-100 scale;
    - training time;
    - dead-neuron percentage.

    The figure is saved to ``wmt_activation_comparison.png`` and then shown.

    Args:
        results: list of dicts with keys ``name``, ``history``, ``val_acc``,
            ``bleu``, ``training_time`` and ``dead_neurons`` (fraction in [0, 1]).
    """
    names = [res['name'] for res in results]
    x = np.arange(len(names))
    width = 0.35

    plt.figure(figsize=(20, 15))

    # 1. Validation Accuracy Plot
    plt.subplot(2, 2, 1)
    plotted_any = False
    for res in results:
        if 'val_accuracy' in res['history']:
            plt.plot(res['history']['val_accuracy'], label=res['name'])
            plotted_any = True
    plt.title('Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    if plotted_any:  # legend() with no labelled artists only emits a warning
        plt.legend()

    # 2. Validation performance. Both metrics are scaled to 0-100 so they share
    #    one axis; raw accuracy (0-1) would be invisible next to BLEU * 100.
    plt.subplot(2, 2, 2)
    acc = [res['val_acc'] * 100 for res in results]
    bleu = [res['bleu'] * 100 for res in results]

    plt.bar(x - width / 2, acc, width, label='Validation Accuracy (%)')
    plt.bar(x + width / 2, bleu, width, label='BLEU Score * 100')

    plt.title('Validation Performance')
    plt.xlabel('Activation Function')
    plt.ylabel('Score')
    plt.xticks(x, names, rotation=45)
    plt.legend()

    # 3. Training Time
    plt.subplot(2, 2, 3)
    times = [res['training_time'] for res in results]
    plt.bar(names, times)
    plt.title('Training Time')
    plt.xlabel('Activation Function')
    plt.xticks(rotation=45)
    plt.ylabel('Seconds')

    # 4. Dead Neurons: stored as fractions, plotted as percentages to match the label
    plt.subplot(2, 2, 4)
    dead = [res['dead_neurons'] * 100 for res in results]
    plt.bar(names, dead)
    plt.title('Dead Neurons Percentage')
    plt.xlabel('Activation Function')
    plt.xticks(rotation=45)
    plt.ylabel('Percentage')

    plt.tight_layout()
    plt.savefig('wmt_activation_comparison.png')
    plt.show()

# ======================
# Final Comparison
# ======================
def output_comparison_table(results):
    """Summarize experiment results as a table of formatted strings.

    Rows are ordered from highest to lowest validation accuracy (ties keep
    their input order). Every metric is pre-formatted as a string.

    Args:
        results: list of dicts with keys ``name``, ``val_acc``, ``val_loss``,
            ``bleu``, ``training_time`` and ``dead_neurons``.

    Returns:
        A ``pandas.DataFrame`` with one row per result.
    """
    def accuracy_of(entry):
        return entry['val_acc']

    rows = []
    for entry in sorted(results, key=accuracy_of, reverse=True):
        row = {}
        row['Activation'] = entry['name']
        row['Val Accuracy'] = f"{entry['val_acc']:.4f}"
        row['Val Loss'] = f"{entry['val_loss']:.4f}"
        row['BLEU Score'] = f"{entry['bleu']:.4f}"
        row['Training Time (s)'] = f"{entry['training_time']:.1f}"
        row['Dead Neurons (%)'] = f"{entry['dead_neurons']:.2%}"
        rows.append(row)

    table = pd.DataFrame(rows)
    return table

# ======================
# Main Execution
# ======================
if __name__ == "__main__":
    # Set random seeds for reproducibility
    np.random.seed(42)
    tf.random.set_seed(42)

    # Import the BLEU module now so a missing NLTK install fails immediately
    # rather than after the dataset download and every training run.
    # corpus_bleu works on already-tokenized word lists, so no NLTK data
    # package (such as 'punkt') has to be downloaded.
    import nltk.translate.bleu_score  # noqa: F401

    results = run_experiment()
    visualize_results(results)
    comparison_table = output_comparison_table(results)

    print("\nActivation Function Comparison (sorted by validation accuracy):")
    print(comparison_table)
    comparison_table.to_csv('wmt_activation_comparison.csv', index=False)