# In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility: NumPy's global RNG (used below to
# generate the synthetic data) and TensorFlow's global seed.
np.random.seed(42)
tf.random.set_seed(42)

print("TensorFlow version:", tf.__version__)
print("Attention Fusion Systems initialized!")

# Set plotting style. Matplotlib 3.6 renamed its bundled seaborn style to
# 'seaborn-v0_8'; fall back to the old name on earlier versions instead of
# failing with an unknown-style error at startup.
plt.style.use(
    'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
)
sns.set_palette("husl")

# 1. Multi-Head Attention Layer
class MultiHeadAttention(layers.Layer):
    """
    Multi-head scaled dot-product attention layer for fusion with LSTMs.

    Queries, keys and values are each projected to ``d_model`` units, split
    into ``num_heads`` heads of ``depth = d_model // num_heads`` units,
    attended per head, then concatenated and passed through a final
    ``d_model``-unit projection.

    NOTE: the fusion models built in this file use the built-in
    ``keras.layers.MultiHeadAttention``; this class is a standalone
    implementation.

    Args:
        d_model: Projection width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (must be positive).
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would let a bad config fail later with an opaque reshape error.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        # Input projections for queries, keys and values.
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)

        # Output projection applied to the concatenated heads.
        self.dense = layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape projected ``x`` into ``(batch, num_heads, seq, depth)``."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """
        Attend from queries ``q`` over keys ``k`` / values ``v``.

        Note the argument order is ``(v, k, q)``.

        Args:
            v: Values; assumed ``(batch, seq_k, features)``.
            k: Keys; assumed ``(batch, seq_k, features)``.
            q: Queries; assumed ``(batch, seq_q, features)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_q, seq_k)``; positions where it is 1
                are suppressed. Bool, integer and float masks are accepted.

        Returns:
            ``(output, attention_weights)``: ``output`` has shape
            ``(batch, seq_q, d_model)``; ``attention_weights`` has shape
            ``(batch, num_heads, seq_q, seq_k)`` and sums to 1 over the last
            axis.
        """
        batch_size = tf.shape(q)[0]

        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # Scaled dot-product attention: softmax(q k^T / sqrt(depth)) v.
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        # Cast to the logits' dtype rather than a fixed float32 so the
        # division also works under mixed precision (float16 activations).
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            logits_dtype = scaled_attention_logits.dtype
            # -1e9 overflows float16 to -inf, and 0 * -inf is NaN at unmasked
            # positions; use float16's largest finite negative there instead.
            large_negative = -65500.0 if logits_dtype == tf.float16 else -1e9
            scaled_attention_logits += tf.cast(mask, logits_dtype) * large_negative

        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        output = tf.matmul(attention_weights, v)

        # (batch, heads, seq_q, depth) -> (batch, seq_q, heads * depth)
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(output, (batch_size, -1, self.d_model))

        output = self.dense(concat_attention)

        return output, attention_weights

    def get_config(self):
        """Return the layer config, including the custom constructor args."""
        config = super().get_config()
        config.update({"d_model": self.d_model, "num_heads": self.num_heads})
        return config

# 2. Attention-LSTM Fusion Architectures
def create_attention_lstm_models(input_shape, lstm_units=64, num_classes=3,
                                 num_heads=4, dropout_rate=0.3):
    """
    Build four attention/LSTM fusion classifiers over the same input shape.

    Args:
        input_shape: Per-sample input shape, e.g. ``(seq_len, features)``;
            the last entry is used as the feature width.
        lstm_units: Width of the LSTM layers (the hierarchical model's
            second LSTM uses ``lstm_units // 2``, at least 1).
        num_classes: Number of softmax output classes.
        num_heads: Heads in every ``MultiHeadAttention`` layer.
        dropout_rate: Dropout rate in every classification head.

    Returns:
        dict mapping name -> uncompiled ``keras.Model`` for
        ``'Post_LSTM_Attention'``, ``'Pre_LSTM_Attention'``,
        ``'Parallel_Fusion'`` and ``'Hierarchical_Fusion'``.
    """
    models = {}
    feature_dim = input_shape[-1]

    def key_dim_for(width):
        # Per-head key size: split the attended width across heads, but never
        # below 1 (width // num_heads is 0 whenever width < num_heads).
        return max(1, width // num_heads)

    def self_attention(x, width):
        # Self-attention: the same tensor serves as query and value.
        return layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim_for(width)
        )(x, x)

    def classification_head(x, hidden_units):
        # Dense(ReLU) -> Dropout -> softmax classifier shared by all models.
        x = layers.Dense(hidden_units, activation='relu')(x)
        x = layers.Dropout(dropout_rate)(x)
        return layers.Dense(num_classes, activation='softmax')(x)

    # Model 1: Post-LSTM attention -- LSTM sequence outputs -> self-attention
    # -> average over time -> classifier.
    post_input = layers.Input(shape=input_shape)
    post_seq = layers.LSTM(lstm_units, return_sequences=True)(post_input)
    post_attended = self_attention(post_seq, lstm_units)
    post_pooled = layers.GlobalAveragePooling1D()(post_attended)
    models['Post_LSTM_Attention'] = keras.Model(
        inputs=post_input,
        outputs=classification_head(post_pooled, 32),
        name='Post_LSTM_Attention',
    )

    # Model 2: Pre-LSTM attention -- self-attention on the raw inputs, then an
    # LSTM that keeps only its final state.
    pre_input = layers.Input(shape=input_shape)
    pre_attended = self_attention(pre_input, feature_dim)
    pre_state = layers.LSTM(lstm_units, return_sequences=False)(pre_attended)
    models['Pre_LSTM_Attention'] = keras.Model(
        inputs=pre_input,
        outputs=classification_head(pre_state, 32),
        name='Pre_LSTM_Attention',
    )

    # Model 3: Parallel fusion -- an LSTM branch and an attention branch, each
    # averaged over time, concatenated before the classifier.
    parallel_input = layers.Input(shape=input_shape)
    lstm_branch = layers.LSTM(lstm_units, return_sequences=True)(parallel_input)
    lstm_pooled = layers.GlobalAveragePooling1D()(lstm_branch)
    attention_branch = self_attention(parallel_input, feature_dim)
    attention_pooled = layers.GlobalAveragePooling1D()(attention_branch)
    fused = layers.Concatenate()([lstm_pooled, attention_pooled])
    models['Parallel_Fusion'] = keras.Model(
        inputs=parallel_input,
        outputs=classification_head(fused, 64),
        name='Parallel_Fusion',
    )

    # Model 4: Hierarchical fusion -- LSTM -> self-attention -> a second,
    # narrower LSTM that keeps only its final state.
    hier_input = layers.Input(shape=input_shape)
    lstm_l1 = layers.LSTM(lstm_units, return_sequences=True)(hier_input)
    attention_l2 = self_attention(lstm_l1, lstm_units)
    lstm_l3 = layers.LSTM(
        max(1, lstm_units // 2), return_sequences=False
    )(attention_l2)
    models['Hierarchical_Fusion'] = keras.Model(
        inputs=hier_input,
        outputs=classification_head(lstm_l3, 32),
        name='Hierarchical_Fusion',
    )

    return models

# 3. Attention Visualization System
class AttentionVisualizer:
    """
    Produce and plot attention-pattern summaries for fusion models.

    NOTE: ``extract_attention_weights`` does not read weights out of the
    model; it returns *simulated* softmax-normalized weights shaped like a
    multi-head attention output. Treat the plots as a demonstration of the
    visualization, not as an analysis of a trained model.
    """

    def __init__(self):
        # Most recent weights from extract_attention_weights, keyed by
        # model name.
        self.attention_weights = {}

    @staticmethod
    def _softmax(logits, axis=-1):
        """Numerically stable softmax of a NumPy array along ``axis``."""
        shifted = logits - np.max(logits, axis=axis, keepdims=True)
        exp = np.exp(shifted)
        return exp / np.sum(exp, axis=axis, keepdims=True)

    @staticmethod
    def _row_entropy(probs):
        """Shannon entropy (nats) of each row of a probability matrix."""
        # The epsilon keeps log() finite for zero-probability entries.
        return -np.sum(probs * np.log(probs + 1e-10), axis=-1)

    def extract_attention_weights(self, model, input_data):
        """
        Return simulated attention weights for ``input_data``.

        Args:
            model: Keras model; only inspected for
                ``layers.MultiHeadAttention`` layers.
            input_data: Array whose shape unpacks as
                ``(batch_size, seq_len, features)``.

        Returns:
            Float array ``(batch_size, num_heads, seq_len, seq_len)`` whose
            last axis sums to 1, or ``None`` if ``model`` has no attention
            layer. ``num_heads`` comes from the first attention layer's config.
        """
        attention_layers = [
            layer for layer in model.layers
            if isinstance(layer, layers.MultiHeadAttention)
        ]
        if not attention_layers:
            print("No attention layers found in model")
            return None

        # Simulated, not extracted: reading the real scores would require
        # calling the attention layer with return_attention_scores=True on
        # the tensors it actually receives inside the model.
        batch_size, seq_len, _ = input_data.shape
        num_heads = attention_layers[0].get_config().get('num_heads', 4)

        logits = np.random.randn(batch_size, num_heads, seq_len, seq_len)
        weights = self._softmax(logits, axis=-1)

        self.attention_weights[model.name] = weights
        return weights

    def visualize_attention_patterns(self, attention_weights, input_sequence=None):
        """
        Plot six summaries of the first sample's attention weights.

        Args:
            attention_weights: Array ``(batch, num_heads, seq_q, seq_k)`` whose
                last axis sums to 1 (as returned by
                ``extract_attention_weights``), or ``None``.
            input_sequence: Unused; kept for interface compatibility.
        """
        if attention_weights is None:
            print("No attention weights to visualize")
            return

        head_attention = np.asarray(attention_weights[0])  # (heads, seq_q, seq_k)
        sample_attention = head_attention.mean(axis=0)     # averaged over heads
        num_queries = sample_attention.shape[0]

        plt.figure(figsize=(15, 10))

        # 1. Head-averaged attention heatmap
        plt.subplot(2, 3, 1)
        sns.heatmap(sample_attention, cmap='Blues', cbar=True)
        plt.title('Attention Weights Heatmap')
        plt.xlabel('Key Position')
        plt.ylabel('Query Position')

        # 2. Attention distribution for about five evenly spaced queries
        plt.subplot(2, 3, 2)
        for i in range(0, num_queries, max(1, num_queries // 5)):
            plt.plot(sample_attention[i], label=f'Query {i}', alpha=0.7)
        plt.title('Attention Distribution by Query')
        plt.xlabel('Key Position')
        plt.ylabel('Attention Weight')
        plt.legend()

        # 3. Attention each key position receives, averaged over queries
        plt.subplot(2, 3, 3)
        avg_attention = sample_attention.mean(axis=0)
        plt.bar(range(len(avg_attention)), avg_attention, alpha=0.7)
        plt.title('Average Attention per Position')
        plt.xlabel('Position')
        plt.ylabel('Average Attention')

        # 4. Per-head comparison. The plain mean of a row-normalized matrix is
        # always 1/seq_k, identical for every head, so compare how peaked each
        # head is instead: the largest weight per query, averaged over queries.
        plt.subplot(2, 3, 4)
        head_peaks = head_attention.max(axis=-1).mean(axis=-1)
        plt.bar(range(len(head_peaks)), head_peaks, alpha=0.7)
        plt.title('Average Peak Attention by Head')
        plt.xlabel('Head Index')
        plt.ylabel('Mean Max Attention Weight')

        # 5. Attention entropy per query (higher = more spread out)
        plt.subplot(2, 3, 5)
        plt.plot(self._row_entropy(sample_attention), 'o-', alpha=0.7)
        plt.title('Attention Entropy by Query Position')
        plt.xlabel('Query Position')
        plt.ylabel('Entropy (Higher = More Spread)')

        # 6. Summary statistics of the head-averaged matrix
        plt.subplot(2, 3, 6)
        stats = {
            'Max': sample_attention.max(),
            'Min': sample_attention.min(),
            'Mean': sample_attention.mean(),
            'Std': sample_attention.std()
        }
        plt.bar(stats.keys(), stats.values(), alpha=0.7)
        plt.title('Attention Matrix Statistics')
        plt.ylabel('Value')

        plt.tight_layout()
        plt.show()

# Generate sample data for attention-LSTM fusion
def create_complex_sequence_data(num_samples=800, seq_length=40, features=16):
    """
    Generate a synthetic 3-class sequence-classification dataset.

    Every sample is Gaussian noise (std 0.5) of shape (seq_length, features).
    Labels cycle 0, 1, 2 by sample index:
      - 0: the first 5 timesteps are shifted by +2.0 (signal at the start);
      - 1: the last 5 timesteps are shifted by +2.0 (signal at the end);
      - 2: 3 distinct random timesteps are shifted by +1.5 (scattered signal).

    Randomness comes from NumPy's global RNG (``np.random``), so seeding it
    beforehand makes the output reproducible.

    Returns:
        X: float array of shape (num_samples, seq_length, features).
        y: integer array of shape (num_samples,).
    """
    # Preallocate so the shapes are correct even when num_samples == 0
    # (np.array of an empty list would be 1-D with shape (0,)).
    X = np.empty((num_samples, seq_length, features))
    y = np.arange(num_samples) % 3

    for i in range(num_samples):
        label = i % 3
        if label == 0:
            # Pattern requiring early attention
            seq = np.random.randn(seq_length, features) * 0.5
            seq[:5] += 2.0
        elif label == 1:
            # Pattern requiring late attention
            seq = np.random.randn(seq_length, features) * 0.5
            seq[-5:] += 2.0
        else:
            # Pattern requiring distributed attention. Positions are drawn
            # before the noise, which fixes the global-RNG call order.
            important_positions = np.random.choice(seq_length, 3, replace=False)
            seq = np.random.randn(seq_length, features) * 0.5
            seq[important_positions] += 1.5
        X[i] = seq

    return X, y

# Main execution: build the synthetic dataset, split it, then build, train
# and score each attention-LSTM fusion model.
print("Creating complex sequence data for attention-LSTM fusion...")
# 600 samples of 30 timesteps x 12 features; labels cycle 0, 1, 2 by index.
X, y = create_complex_sequence_data(num_samples=600, seq_length=30, features=12)

# Split data: sequential 80/20 split with no shuffling. Because labels cycle
# 0, 1, 2 by sample index, both halves contain the three classes equally
# (480 = 3 x 160 train, 120 = 3 x 40 validation).
split_idx = int(0.8 * len(X))
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]

print(f"Training data shape: {X_train.shape}")
print(f"Validation data shape: {X_val.shape}")

# Create attention-LSTM fusion models; input_shape is (seq_length, features).
input_shape = X_train.shape[1:]
fusion_models = create_attention_lstm_models(input_shape, lstm_units=64, num_classes=3)

print(f"\nCreated {len(fusion_models)} attention-LSTM fusion models:")
for name, model in fusion_models.items():
    print(f"- {name}: {model.count_params():,} parameters")

# Train and evaluate models (simplified for demonstration): the factory
# returns uncompiled models, so each is compiled here and trained for a
# fixed 10 epochs with no early stopping or checkpointing.
print(f"\nTraining attention-LSTM fusion models...")
results = {}

for name, model in fusion_models.items():
    print(f"\nTraining {name}...")
    
    # Labels are integer class ids (0, 1, 2), hence the sparse loss.
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=10,
        batch_size=32,
        verbose=0
    )
    
    # 'final_val_acc' is the last epoch's validation accuracy, not the best
    # epoch's; later cells rank models by it.
    results[name] = {
        'model': model,
        'history': history.history,
        'final_val_acc': history.history['val_accuracy'][-1]
    }
    
    print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")

# Visualize attention patterns.
# NOTE(review): AttentionVisualizer.extract_attention_weights is written to
# return randomly generated (simulated) weights of the right shape, not the
# trained model's attention scores, so these plots illustrate the
# visualization only. It also draws from NumPy's global RNG.
print(f"\nGenerating attention visualizations...")
visualizer = AttentionVisualizer()

# Use one of the models for attention visualization
sample_model = fusion_models['Post_LSTM_Attention']
sample_data = X_val[:5]  # First 5 validation samples; only sample 0 is plotted

attention_weights = visualizer.extract_attention_weights(sample_model, sample_data)
visualizer.visualize_attention_patterns(attention_weights)

# Performance comparison: three panels laid out as 1x3, so no grid cells are
# left empty.
plt.figure(figsize=(18, 6))

# Validation-accuracy curve of every model, per epoch
plt.subplot(1, 3, 1)
for name, result in results.items():
    plt.plot(result['history']['val_accuracy'], label=name, alpha=0.7)
plt.title('Validation Accuracy Curves')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)

# Final-epoch validation accuracy, annotated above each bar
plt.subplot(1, 3, 2)
names = list(results.keys())
final_accs = [results[name]['final_val_acc'] for name in names]
bars = plt.bar(range(len(names)), final_accs, alpha=0.7)
plt.title('Final Model Performance')
plt.ylabel('Validation Accuracy')
plt.xticks(range(len(names)), names, rotation=45, ha='right')
# Accuracy is at most 1.0. Leave headroom so value labels above bars near 1.0
# stay inside the axes instead of colliding with the title.
plt.ylim(0, 1.1)

for bar, acc in zip(bars, final_accs):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
             f'{acc:.3f}', ha='center', va='bottom')

# Parameter efficiency: final accuracy per 1,000 model parameters
plt.subplot(1, 3, 3)
param_counts = [fusion_models[name].count_params() for name in names]
efficiency = [acc / (params / 1000) for acc, params in zip(final_accs, param_counts)]
plt.bar(range(len(names)), efficiency, alpha=0.7, color='green')
plt.title('Parameter Efficiency')
plt.ylabel('Accuracy per 1K Parameters')
plt.xticks(range(len(names)), names, rotation=45, ha='right')

plt.tight_layout()
plt.show()

# Text summary: per-model accuracy and parameter count, the best model by
# final-epoch validation accuracy, then closing notes.
print("\nAttention-LSTM Fusion Analysis Summary:", "=" * 50, sep="\n")

# max() over the dict iterates its keys; ties resolve to the first model.
best_model = max(results, key=lambda model_name: results[model_name]['final_val_acc'])
best_accuracy = results[best_model]['final_val_acc']

for name, result in results.items():
    acc = result['final_val_acc']
    params = fusion_models[name].count_params()
    print(
        f"\n{name}:",
        f"  Validation Accuracy: {acc:.4f}",
        f"  Parameters: {params:,}",
        sep="\n",
    )

print(f"\nBest performing fusion model: {best_model} ({best_accuracy:.4f} accuracy)")

print(
    "\nKey Insights from Attention-LSTM Fusion:",
    "- Attention mechanisms enhance LSTM capabilities for complex patterns",
    "- Different fusion strategies work better for different task types",
    "- Hierarchical fusion can capture multi-level dependencies",
    "- Attention provides interpretability into model decisions",
    sep="\n",
)

print(
    "\nAttention Fusion Systems Complete!",
    "Advanced LSTM architectures with attention integration mastered!",
    sep="\n",
)
