# In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import warnings
# Silence all Python warnings for the rest of the session to keep notebook
# output readable. NOTE(review): this is global and also hides deprecation
# warnings from every library imported above.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility (NumPy global state and TensorFlow ops)
np.random.seed(42)
tf.random.set_seed(42)

print("TensorFlow version:", tf.__version__)
print("LSTM Architecture Deep Dive initialized!")

# Set plotting style. matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8*'; fall back to the old name so older installs still work.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
sns.set_palette("husl")


# In [None]:
# 1. Custom LSTM Implementation for Understanding
class CustomLSTMCell:
    """
    Custom LSTM cell implementation to understand the mathematics.

    With z = [h_{t-1}; x_t] (previous hidden state stacked on top of the
    input), one step computes::

        f_t  = sigmoid(W_f z + b_f)        forget gate
        i_t  = sigmoid(W_i z + b_i)        input gate
        C~_t = tanh(W_C z + b_C)           candidate values
        o_t  = sigmoid(W_o z + b_o)        output gate
        C_t  = f_t * C_{t-1} + i_t * C~_t  new cell state
        h_t  = o_t * tanh(C_t)             new hidden state

    Each weight matrix has shape (hidden_size, input_size + hidden_size) and
    each bias has shape (hidden_size, 1).

    Args:
        input_size: Dimension of the input vector x_t.
        hidden_size: Dimension of the hidden and cell states.
        rng: Optional ``numpy.random.Generator`` used to draw the initial
            weights. When ``None`` (default) the global ``np.random`` state is
            used, so ``np.random.seed`` controls the initial weights.
        forget_bias: Initial value of every entry of ``b_f``. The default 0.0
            gives a plain zero bias; a positive value such as 1.0 pushes the
            initial forget gate above sigmoid(0) = 0.5, so a fresh cell keeps
            more of its previous cell state.
    """

    def __init__(self, input_size, hidden_size, rng=None, forget_bias=0.0):
        self.input_size = input_size
        self.hidden_size = hidden_size

        weight_shape = (hidden_size, input_size + hidden_size)

        def init_weight():
            # Small Gaussian weights (std 0.1, simplified initialization).
            # The call order below (f, i, C, o) is fixed so seeded runs
            # reproduce the same parameters.
            if rng is None:
                return np.random.randn(*weight_shape) * 0.1
            return rng.standard_normal(weight_shape) * 0.1

        self.W_f = init_weight()  # Forget gate
        self.b_f = np.full((hidden_size, 1), float(forget_bias))

        self.W_i = init_weight()  # Input gate
        self.b_i = np.zeros((hidden_size, 1))

        self.W_C = init_weight()  # Candidate values
        self.b_C = np.zeros((hidden_size, 1))

        self.W_o = init_weight()  # Output gate
        self.b_o = np.zeros((hidden_size, 1))

    def sigmoid(self, x):
        """Sigmoid activation; the input is clipped to avoid overflow in exp."""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    def tanh(self, x):
        """Tanh activation (input clipped to [-500, 500])."""
        return np.tanh(np.clip(x, -500, 500))

    @staticmethod
    def _as_column(value, size, name):
        """Return *value* as a 2-D array with *size* rows (1-D input becomes one column)."""
        arr = np.asarray(value)
        if arr.ndim < 2:
            arr = arr.reshape(-1, 1)
        if arr.ndim != 2 or arr.shape[0] != size:
            raise ValueError(
                f"{name} must have shape ({size}, 1) or ({size},); got {arr.shape}"
            )
        return arr

    def forward(self, x_t, h_prev, C_prev):
        """
        Forward pass through LSTM cell.

        Args:
            x_t: Input at time t, shape (input_size, 1). A 1-D vector of length
                input_size is also accepted. An (input_size, k) matrix processes
                k columns at once (h_prev must then also have k columns).
            h_prev: Previous hidden state, shape (hidden_size, 1) or (hidden_size,).
            C_prev: Previous cell state, shape (hidden_size, 1) or (hidden_size,).

        Returns:
            h_t: New hidden state (2-D, hidden_size rows).
            C_t: New cell state (2-D, hidden_size rows).
            gates: Dictionary containing gate values for analysis, keyed by
                'forget_gate', 'input_gate', 'candidate_values', 'output_gate'.

        Raises:
            ValueError: If an argument's leading dimension does not match
                input_size / hidden_size, or it has more than two dimensions.
        """
        x_t = self._as_column(x_t, self.input_size, 'x_t')
        h_prev = self._as_column(h_prev, self.hidden_size, 'h_prev')
        C_prev = self._as_column(C_prev, self.hidden_size, 'C_prev')

        # Concatenate previous hidden state and input: z = [h_{t-1}; x_t]
        concat_input = np.vstack([h_prev, x_t])

        # Compute gates
        f_t = self.sigmoid(self.W_f @ concat_input + self.b_f)  # Forget gate
        i_t = self.sigmoid(self.W_i @ concat_input + self.b_i)  # Input gate
        C_tilde = self.tanh(self.W_C @ concat_input + self.b_C)  # Candidate values
        o_t = self.sigmoid(self.W_o @ concat_input + self.b_o)  # Output gate

        # Keep part of the old memory, add part of the new candidate
        C_t = f_t * C_prev + i_t * C_tilde

        # Expose a gated, squashed view of the cell state
        h_t = o_t * self.tanh(C_t)

        gates = {
            'forget_gate': f_t,
            'input_gate': i_t,
            'candidate_values': C_tilde,
            'output_gate': o_t,
        }

        return h_t, C_t, gates

# 2. LSTM Analysis and Visualization
class LSTMAnalyzer:
    """
    Analyze LSTM behavior and visualize internal states.

    The wrapped cell must expose ``hidden_size`` and a
    ``forward(x_t, h_prev, C_prev) -> (h_t, C_t, gates)`` method whose
    ``gates`` dict contains 'forget_gate', 'input_gate', 'candidate_values'
    and 'output_gate' arrays (as ``CustomLSTMCell.forward`` returns).
    """

    # Gate entries copied from the cell's gates dict into each recorded step.
    _GATE_KEYS = ('forget_gate', 'input_gate', 'candidate_values', 'output_gate')

    def __init__(self, lstm_cell):
        self.lstm_cell = lstm_cell
        self.states_history = []

    def run_sequence(self, input_sequence):
        """
        Run the cell over a sequence and record every intermediate state.

        Hidden and cell states start at zero. Each element of
        ``input_sequence`` (array, list or tuple) is reshaped into a column
        vector before being fed to the cell.

        Returns:
            A list with one dict per time step holding 'timestep', 'input',
            'hidden_state', 'cell_state' and the four gate entries, each
            flattened to 1-D. The same list is stored on ``self.states_history``.
        """
        hidden_size = self.lstm_cell.hidden_size

        # Initialize states
        h_t = np.zeros((hidden_size, 1))
        C_t = np.zeros((hidden_size, 1))

        history = []
        for t, x_t in enumerate(input_sequence):
            # np.asarray lets callers pass plain lists/tuples as well as arrays.
            x_t = np.asarray(x_t).reshape(-1, 1)
            h_t, C_t, gates = self.lstm_cell.forward(x_t, h_t, C_t)

            state_info = {
                'timestep': t,
                'input': x_t.flatten(),
                'hidden_state': h_t.flatten(),
                'cell_state': C_t.flatten(),
            }
            state_info.update((key, gates[key].flatten()) for key in self._GATE_KEYS)
            history.append(state_info)

        self.states_history = history
        return history

    def _stack(self, key):
        """Stack the recorded per-step vectors stored under *key* into a (T, n) array."""
        return np.array([state[key] for state in self.states_history])

    @staticmethod
    def _plot_units(ax, timesteps, values, prefix, title, ylabel, max_units=3):
        """Plot up to *max_units* columns of a (T, n) array as one line per unit."""
        for i in range(min(max_units, values.shape[1])):
            ax.plot(timesteps, values[:, i], label=f'{prefix}_{i}', alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel('Time Step')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _gate_heatmap(ax, gate_values, title):
        """Draw a (T, n) gate array as a units-by-time heatmap on a fixed 0..1 colour scale."""
        # A shared vmin/vmax keeps the three gate heatmaps visually comparable
        # (gates from CustomLSTMCell are sigmoid outputs in [0, 1]).
        image = ax.imshow(gate_values.T, aspect='auto', cmap='RdYlBu_r',
                          interpolation='nearest', vmin=0.0, vmax=1.0)
        ax.set_title(title)
        ax.set_xlabel('Time Step')
        ax.set_ylabel('Hidden Unit')
        plt.colorbar(image, ax=ax)

    def visualize_lstm_dynamics(self):
        """
        Create comprehensive visualization of LSTM internal dynamics.

        Draws a 3x3 grid for the most recent ``run_sequence`` call:
          * row 0 - input features, hidden state and cell state over time
            (first three units of each);
          * row 1 - forget / input / output gate heatmaps (units x time);
          * row 2 - gate activation histograms, mean gate activation per step,
            and the L2 norms of the cell and hidden states.

        Prints a message and returns without plotting if no states are recorded.
        """
        if not self.states_history:
            print("No states to visualize. Run a sequence first.")
            return

        # Extract data for visualization
        timesteps = [state['timestep'] for state in self.states_history]
        inputs = self._stack('input')
        hidden_states = self._stack('hidden_state')
        cell_states = self._stack('cell_state')
        forget_gates = self._stack('forget_gate')
        input_gates = self._stack('input_gate')
        output_gates = self._stack('output_gate')

        fig, axes = plt.subplots(3, 3, figsize=(18, 15))

        # Row 0: raw signals over time
        self._plot_units(axes[0, 0], timesteps, inputs, 'x',
                         'Input Sequence', 'Input Value')
        self._plot_units(axes[0, 1], timesteps, hidden_states, 'h',
                         'Hidden State Evolution', 'Hidden State Value')
        self._plot_units(axes[0, 2], timesteps, cell_states, 'C',
                         'Cell State Evolution', 'Cell State Value')

        # Row 1: gate heatmaps
        self._gate_heatmap(axes[1, 0], forget_gates, 'Forget Gate Activations')
        self._gate_heatmap(axes[1, 1], input_gates, 'Input Gate Activations')
        self._gate_heatmap(axes[1, 2], output_gates, 'Output Gate Activations')

        # Row 2, left: gate activation distributions (bins aligned on [0, 1])
        axes[2, 0].hist(
            [forget_gates.flatten(), input_gates.flatten(), output_gates.flatten()],
            bins=30, range=(0.0, 1.0), alpha=0.7,
            label=['Forget', 'Input', 'Output'], color=['red', 'blue', 'green'],
        )
        axes[2, 0].set_title('Gate Activation Distributions')
        axes[2, 0].set_xlabel('Activation Value')
        axes[2, 0].set_ylabel('Frequency')
        axes[2, 0].legend()

        # Row 2, middle: information flow (mean activation across units per step)
        gate_series = (
            (forget_gates, 'r-', 'Forget Gate'),
            (input_gates, 'b-', 'Input Gate'),
            (output_gates, 'g-', 'Output Gate'),
        )
        for gate_values, line_style, label in gate_series:
            axes[2, 1].plot(timesteps, gate_values.mean(axis=1), line_style,
                            label=label, alpha=0.7)
        axes[2, 1].set_title('Average Gate Activations Over Time')
        axes[2, 1].set_xlabel('Time Step')
        axes[2, 1].set_ylabel('Average Activation')
        axes[2, 1].legend()
        axes[2, 1].grid(True, alpha=0.3)

        # Row 2, right: memory retention (L2 norm of each state per step)
        axes[2, 2].plot(timesteps, np.linalg.norm(cell_states, axis=1), 'purple',
                        label='Cell State Magnitude', alpha=0.7)
        axes[2, 2].plot(timesteps, np.linalg.norm(hidden_states, axis=1), 'orange',
                        label='Hidden State Magnitude', alpha=0.7)
        axes[2, 2].set_title('State Magnitude Evolution')
        axes[2, 2].set_xlabel('Time Step')
        axes[2, 2].set_ylabel('L2 Norm')
        axes[2, 2].legend()
        axes[2, 2].grid(True, alpha=0.3)

        fig.tight_layout()
        plt.show()

# 3. LSTM vs SimpleRNN Comparison
def create_comparison_models(input_size, hidden_size, sequence_length):
    """
    Create LSTM and SimpleRNN models for comparison.

    Each model is a recurrent layer with ``hidden_size`` units (returning the
    full sequence) followed by a linear Dense layer applied at every step that
    projects back to ``input_size`` features. Both are compiled identically
    (Adam, MSE loss, MAE metric) so their results are directly comparable.

    Args:
        input_size: Number of features per time step (input and output).
        hidden_size: Number of recurrent units in each model.
        sequence_length: Number of time steps per sample. Pass ``None`` to
            accept sequences of any length.

    Returns:
        Tuple ``(lstm_model, rnn_model)`` of compiled ``keras.Sequential`` models.
    """
    def build(recurrent_layer, name):
        # Declaring the input shape builds the model immediately, so
        # summary()/count_params() work before the first fit.
        model = keras.Sequential([
            keras.Input(shape=(sequence_length, input_size)),
            recurrent_layer,
            layers.Dense(input_size, activation='linear'),
        ], name=name)
        model.compile(optimizer='adam', loss='mse', metrics=['mae'])
        return model

    lstm_model = build(layers.LSTM(hidden_size, return_sequences=True), 'LSTM_Model')
    rnn_model = build(layers.SimpleRNN(hidden_size, return_sequences=True), 'SimpleRNN_Model')

    return lstm_model, rnn_model

# Initialize custom LSTM for analysis (3 input features, 8 hidden units)
input_size = 3
hidden_size = 8
custom_lstm = CustomLSTMCell(input_size, hidden_size)
analyzer = LSTMAnalyzer(custom_lstm)

# Create test sequence: 8 time steps, each a 3-feature input vector
test_sequence = [
    np.array([1.0, 0.5, 0.2]),   # t=0
    np.array([0.8, 1.0, 0.1]),   # t=1
    np.array([0.3, 0.7, 0.9]),   # t=2
    np.array([0.1, 0.2, 0.8]),   # t=3
    np.array([0.9, 0.1, 0.3]),   # t=4
    np.array([0.5, 0.8, 0.4]),   # t=5
    np.array([0.2, 0.9, 0.7]),   # t=6
    np.array([0.7, 0.3, 0.6])    # t=7
]

print("Running LSTM analysis on test sequence...")
states_history = analyzer.run_sequence(test_sequence)

print(f"Analyzed {len(states_history)} time steps")
print("Generating LSTM dynamics visualization...")

# Visualize LSTM dynamics
analyzer.visualize_lstm_dynamics()

# Print some statistics
print("\nLSTM Analysis Summary:")
print("=" * 40)

final_state = states_history[-1]
print(f"Final hidden state range: [{final_state['hidden_state'].min():.3f}, {final_state['hidden_state'].max():.3f}]")
print(f"Final cell state range: [{final_state['cell_state'].min():.3f}, {final_state['cell_state'].max():.3f}]")

# Gate activation statistics, pooled over all time steps and hidden units
all_forget = np.array([state['forget_gate'] for state in states_history])
all_input = np.array([state['input_gate'] for state in states_history])
all_output = np.array([state['output_gate'] for state in states_history])

print("\nGate Activation Statistics:")
print(f"Forget gate - Mean: {all_forget.mean():.3f}, Std: {all_forget.std():.3f}")
print(f"Input gate - Mean: {all_input.mean():.3f}, Std: {all_input.std():.3f}")
print(f"Output gate - Mean: {all_output.mean():.3f}, Std: {all_output.std():.3f}")

print("\nLSTM Architecture Deep Dive Complete!")
print("Understanding of gate mechanisms and cell dynamics achieved!")
print("Ready for LSTM variations and improvements!")
