# Simple PyTorch Neural Network Implementations

## 🎯 **Exam Preparation - Core Neural Network Concepts**

This notebook contains **simple, clean implementations** of essential neural network architectures in PyTorch, intended for exam revision and for understanding the core concepts.

### **📚 What's Covered:**
1. **🧠 Artificial Neural Network (ANN)** - Basic feedforward networks
2. **🖼️ Convolutional Neural Network (CNN)** - Image processing networks  
3. **🔄 Recurrent Neural Network (RNN)** - Sequential data processing
4. **🧮 Long Short-Term Memory (LSTM)** - Advanced sequence modeling
5. **⚡ Gated Recurrent Unit (GRU)** - Efficient sequence processing
6. **🏋️ Training & Evaluation** - Complete training pipelines

### **🎓 Learning Objectives:**
- Understand **architecture fundamentals** of each network type
- Learn **PyTorch implementation patterns**
- Master **forward pass mechanics**
- Practice **loss calculation and backpropagation**
- Compare **performance characteristics**

---

In [None]:
# Import Required Libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

# Data handling and visualization
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification, make_regression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, mean_squared_error

# Utilities
import time
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("✅ All libraries imported successfully!")
print(f"📦 PyTorch version: {torch.__version__}")
print(f"🎯 Device: {device}")

## 📊 Sample Datasets

We'll create simple synthetic datasets to test our neural network implementations:

In [None]:
# Sample Datasets for Testing Neural Networks

def create_classification_data(n_samples=1000, n_features=10, n_classes=3):
    """Create synthetic classification dataset for ANN."""
    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_classes=n_classes,
        n_informative=n_features//2,
        n_redundant=0,
        n_clusters_per_class=1,
        random_state=42
    )
    
    # Normalize features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    
    return torch.FloatTensor(X), torch.LongTensor(y)

def create_regression_data(n_samples=1000, n_features=5):
    """Create synthetic regression dataset."""
    X, y = make_regression(
        n_samples=n_samples,
        n_features=n_features,
        noise=0.1,
        random_state=42
    )
    
    # Normalize features and targets
    scaler_X = StandardScaler()
    scaler_y = StandardScaler()
    
    X = scaler_X.fit_transform(X)
    y = scaler_y.fit_transform(y.reshape(-1, 1)).flatten()
    
    return torch.FloatTensor(X), torch.FloatTensor(y)

def create_image_data(n_samples=1000, img_size=28):
    """Create synthetic image data for CNN (simplified MNIST-like)."""
    # Create simple geometric patterns
    X = torch.randn(n_samples, 1, img_size, img_size)
    y = torch.randint(0, 3, (n_samples,))
    
    # Add some pattern structure
    for i in range(n_samples):
        if y[i] == 0:  # Circles
            center = img_size // 2
            radius = img_size // 4
            for r in range(img_size):
                for c in range(img_size):
                    if (r - center)**2 + (c - center)**2 <= radius**2:
                        X[i, 0, r, c] = 1.0
        elif y[i] == 1:  # Horizontal lines
            start_row = img_size // 3
            end_row = 2 * img_size // 3
            X[i, 0, start_row:end_row, :] = 1.0
        else:  # Diagonal lines (both diagonals, forming an X)
            for k in range(img_size):
                X[i, 0, k, k] = 1.0                 # Main diagonal
                X[i, 0, k, img_size - 1 - k] = 1.0  # Anti-diagonal
    
    return X, y

def create_sequence_data(n_samples=1000, seq_len=20, n_features=5):
    """Create synthetic sequence data for RNN/LSTM/GRU."""
    X = torch.randn(n_samples, seq_len, n_features)
    
    # Create targets based on sequence patterns
    # Class 0: Increasing trend, Class 1: Decreasing trend, Class 2: Oscillating
    y = torch.zeros(n_samples, dtype=torch.long)
    
    for i in range(n_samples):
        # Add trend to first feature
        if i % 3 == 0:  # Increasing
            trend = torch.linspace(0, 2, seq_len)
            X[i, :, 0] += trend
            y[i] = 0
        elif i % 3 == 1:  # Decreasing  
            trend = torch.linspace(2, 0, seq_len)
            X[i, :, 0] += trend
            y[i] = 1
        else:  # Oscillating
            trend = torch.sin(torch.linspace(0, 4*np.pi, seq_len))
            X[i, :, 0] += trend
            y[i] = 2
    
    return X, y

# Create all datasets
print("🏗️ Creating datasets...")

# 1. Classification data for ANN
X_clf, y_clf = create_classification_data(n_samples=1000, n_features=10, n_classes=3)
print(f"📈 Classification dataset: {X_clf.shape}, {y_clf.shape}")

# 2. Regression data
X_reg, y_reg = create_regression_data(n_samples=1000, n_features=5)
print(f"📊 Regression dataset: {X_reg.shape}, {y_reg.shape}")

# 3. Image data for CNN
X_img, y_img = create_image_data(n_samples=800, img_size=28)
print(f"🖼️ Image dataset: {X_img.shape}, {y_img.shape}")

# 4. Sequence data for RNN/LSTM/GRU
X_seq, y_seq = create_sequence_data(n_samples=800, seq_len=20, n_features=5)
print(f"🔄 Sequence dataset: {X_seq.shape}, {y_seq.shape}")

print("\n✅ All datasets created successfully!")

# Visualize sample data
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Classification data
axes[0,0].scatter(X_clf[:, 0], X_clf[:, 1], c=y_clf, alpha=0.7)
axes[0,0].set_title('Classification Data (First 2 Features)')
axes[0,0].set_xlabel('Feature 1')
axes[0,0].set_ylabel('Feature 2')

# Regression data
axes[0,1].scatter(X_reg[:, 0], y_reg, alpha=0.7)
axes[0,1].set_title('Regression Data')
axes[0,1].set_xlabel('Feature 1')
axes[0,1].set_ylabel('Target')

# Sample image
sample_img = X_img[0, 0].numpy()
axes[1,0].imshow(sample_img, cmap='gray')
axes[1,0].set_title(f'Sample Image (Class {y_img[0].item()})')
axes[1,0].axis('off')

# Sample sequence
sample_seq = X_seq[0, :, 0].numpy()
axes[1,1].plot(sample_seq)
axes[1,1].set_title(f'Sample Sequence (Class {y_seq[0].item()})')
axes[1,1].set_xlabel('Time Steps')
axes[1,1].set_ylabel('Feature Value')

plt.tight_layout()
plt.show()

## 🧠 1. Simple Artificial Neural Network (ANN)

**Key Concepts:**
- **Feedforward architecture** with fully connected layers
- **Activation functions** (ReLU, Sigmoid, Tanh)
- **Backpropagation** for weight updates
- **Gradient descent optimization**

**Architecture:**
```
Input Layer → Hidden Layer(s) → Output Layer
     ↓              ↓              ↓
  Linear         Linear        Linear
Transform    + Activation    + Output
```
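
The concepts above can be sketched in a few lines. This is a minimal illustration rather than the notebook's training pipeline: the layer sizes, the random batch, and the learning rate `lr = 0.1` are assumptions chosen only for the example. It writes the forward pass out as explicit matrix products, checks the result against `nn.Linear`, and then applies one plain gradient-descent update after backpropagation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration
batch_size, n_in, n_hidden, n_out = 4, 10, 8, 3
x = torch.randn(batch_size, n_in)
y = torch.randint(0, n_out, (batch_size,))

fc1 = nn.Linear(n_in, n_hidden)
fc2 = nn.Linear(n_hidden, n_out)

# Forward pass written out by hand: linear -> ReLU -> linear
hidden = torch.relu(x @ fc1.weight.T + fc1.bias)
logits_manual = hidden @ fc2.weight.T + fc2.bias

# The same computation through the layer objects
logits_module = fc2(torch.relu(fc1(x)))
forward_matches = torch.allclose(logits_manual, logits_module, atol=1e-6)

# Loss calculation and backpropagation (fills .grad on every parameter)
loss = F.cross_entropy(logits_module, y)
loss.backward()

# One plain gradient-descent step: p <- p - lr * dL/dp
lr = 0.1
w_before = fc2.weight.detach().clone()
with torch.no_grad():
    for p in list(fc1.parameters()) + list(fc2.parameters()):
        p -= lr * p.grad
update_matches = torch.allclose(fc2.weight, w_before - lr * fc2.weight.grad)

print(f"Manual forward pass matches nn.Linear: {forward_matches}")
print(f"Weight update equals w - lr * grad:    {update_matches}")
```

In practice an optimizer such as `torch.optim.SGD` performs the update loop; it is spelled out here only to make the gradient-descent rule visible.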

In [None]:
# Simple Artificial Neural Network Implementation

class SimpleANN(nn.Module):
    """
    Simple Artificial Neural Network for classification.
    
    Architecture: Input → Hidden → Output
    """
    
    def __init__(self, input_size, hidden_size, output_size, activation='relu'):
        super(SimpleANN, self).__init__()
        
        # Store architecture info
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size)   # Input → Hidden
        self.fc2 = nn.Linear(hidden_size, output_size)  # Hidden → Output
        
        # Choose activation function
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
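        else:
            # Fail loudly instead of silently falling back to ReLU
            raise ValueError(f"Unsupported activation: {activation!r} (use 'relu', 'sigmoid', or 'tanh')")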
    
    def forward(self, x):
        """
        Forward pass through the network.
        
        Args:
            x: Input tensor (batch_size, input_size)
        
        Returns:
            Output tensor (batch_size, output_size)
        """
        # Input → Hidden (with activation)
        hidden = self.fc1(x)          # Linear transformation
        hidden = self.activation(hidden)  # Non-linear activation
        
        # Hidden → Output (no activation here: the raw logits go to a loss such as
        # nn.CrossEntropyLoss, which applies log-softmax internally)
        output = self.fc2(hidden)     # Final linear transformation
        
        return output
    
    def get_info(self):
        """Return model information."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'Architecture': 'Feedforward Neural Network',
            'Input Size': self.input_size,
            'Hidden Size': self.hidden_size,
            'Output Size': self.output_size,
            'Total Parameters': total_params,
            'Activation': str(self.activation)
        }

# Create and test ANN
print("🧠 Simple ANN Implementation")
print("=" * 40)

# Model parameters
input_size = X_clf.shape[1]  # Number of features
hidden_size = 64
output_size = len(torch.unique(y_clf))  # Number of classes

# Create model
ann_model = SimpleANN(input_size, hidden_size, output_size, activation='relu')

# Display model info
model_info = ann_model.get_info()
for key, value in model_info.items():
    print(f"{key}: {value}")

print(f"\n📊 Model Architecture:")
print(ann_model)

# Test forward pass
sample_input = X_clf[:5]  # First 5 samples
with torch.no_grad():
    sample_output = ann_model(sample_input)
    
print(f"\n🔍 Forward Pass Test:")
print(f"Input shape: {sample_input.shape}")
print(f"Output shape: {sample_output.shape}")
print(f"Sample outputs (raw logits):")
for i, output in enumerate(sample_output):
    predicted_class = torch.argmax(output).item()
    actual_class = y_clf[i].item()
    print(f"  Sample {i+1}: Predicted={predicted_class}, Actual={actual_class}")

print("\n✅ ANN implementation complete!")

## 🖼️ 2. Simple Convolutional Neural Network (CNN)

**Key Concepts:**
- **Convolutional layers** for feature extraction
- **Pooling layers** for dimensionality reduction
- **Feature maps** and **spatial hierarchies**
- **Parameter sharing** across spatial locations

**Architecture:**
```
Input Image → Conv → ReLU → Pool → Conv → ReLU → Pool → Flatten → FC → Output
    ↓          ↓      ↓      ↓      ↓      ↓      ↓       ↓       ↓      ↓
  (1,28,28) (32,28,28) → (32,14,14) (64,14,14) → (64,7,7) → (3136) → (128) → (3)
```
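
The shapes in a convolutional stack follow from one formula, `out = floor((in + 2*padding - kernel) / stride) + 1`. The sketch below applies it to a 28×28 input passed through two Conv(3×3) → MaxPool(2×2) blocks, once with `padding=1` (the setting used by `SimpleCNN` in this notebook) and once with `padding=0` for contrast. The helper names are hypothetical and exist only for this example.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_two_block_cnn(size, padding):
    """Sizes after each layer of Conv(3x3) -> Pool(2x2) -> Conv(3x3) -> Pool(2x2)."""
    sizes = [size]
    for _ in range(2):
        sizes.append(conv_output_size(sizes[-1], kernel_size=3, padding=padding))
        sizes.append(conv_output_size(sizes[-1], kernel_size=2, stride=2))
    return sizes


def conv_param_count(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return out_channels * in_channels * kernel_size * kernel_size + out_channels


with_padding = trace_two_block_cnn(28, padding=1)
without_padding = trace_two_block_cnn(28, padding=0)

print("padding=1:", with_padding, "-> flattened", 64 * with_padding[-1] ** 2)
print("padding=0:", without_padding, "-> flattened", 64 * without_padding[-1] ** 2)
print("conv1 parameters:", conv_param_count(1, 32, 3))
print("conv2 parameters:", conv_param_count(32, 64, 3))
```

With `padding=1` the convolutions preserve the spatial size, so only the pooling layers shrink it (28 → 14 → 7, giving 64·7·7 = 3136 flattened features). With `padding=0` each 3×3 convolution also removes a one-pixel border (28 → 26 → 13 → 11 → 5, giving 1600). The parameter counts do not depend on the image size, which is what parameter sharing means in practice.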

In [None]:
# Simple Convolutional Neural Network Implementation

class SimpleCNN(nn.Module):
    """
    Simple Convolutional Neural Network for image classification.
    
    Architecture: Conv → ReLU → Pool → Conv → ReLU → Pool → Flatten → FC → Output
    """
    
    def __init__(self, input_channels=1, num_classes=3):
        super(SimpleCNN, self).__init__()
        
        # Store architecture info
        self.input_channels = input_channels
        self.num_classes = num_classes
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)  # (1,28,28) → (32,28,28)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)                    # (32,28,28) → (32,14,14)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)              # (32,14,14) → (64,14,14)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)                    # (64,14,14) → (64,7,7)
        
        # Calculate flattened size (depends on input image size)
        # For 28x28 input: 64 * 7 * 7 = 3136
        self.flattened_size = 64 * 7 * 7
        
        # Fully connected layers
        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        """
        Forward pass through CNN.
        
        Args:
            x: Input tensor (batch_size, channels, height, width)
        
        Returns:
            Output tensor (batch_size, num_classes)
        """
        # First convolutional block
        x = self.conv1(x)      # Apply convolution
        x = self.relu(x)       # Apply activation
        x = self.pool1(x)      # Apply pooling
        
        # Second convolutional block
        x = self.conv2(x)      # Apply convolution
        x = self.relu(x)       # Apply activation
        x = self.pool2(x)      # Apply pooling
        
        # Flatten for fully connected layers
        x = x.view(x.size(0), -1)  # Flatten: (batch_size, channels*height*width)
        
        # Fully connected layers
        x = self.fc1(x)        # First FC layer
        x = self.relu(x)       # Activation
        x = self.dropout(x)    # Dropout for regularization
        x = self.fc2(x)        # Output layer
        
        return x
    
    def get_info(self):
        """Return model information."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'Architecture': 'Convolutional Neural Network',
            'Input Channels': self.input_channels,
            'Output Classes': self.num_classes,
            'Total Parameters': total_params,
            'Flattened Size': self.flattened_size
        }

# Create and test CNN
print("🖼️ Simple CNN Implementation")
print("=" * 40)

# Model parameters
input_channels = X_img.shape[1]  # Number of channels
num_classes = len(torch.unique(y_img))  # Number of classes

# Create model
cnn_model = SimpleCNN(input_channels, num_classes)

# Display model info
model_info = cnn_model.get_info()
for key, value in model_info.items():
    print(f"{key}: {value}")

print(f"\n📊 Model Architecture:")
print(cnn_model)

# Test forward pass
sample_input = X_img[:5]  # First 5 images
print(f"\n🔍 Forward Pass Test:")
print(f"Input shape: {sample_input.shape}")

cnn_model.eval()  # Disable dropout so the test output is deterministic
with torch.no_grad():
    sample_output = cnn_model(sample_input)
    
print(f"Output shape: {sample_output.shape}")
print(f"Sample outputs (raw logits):")
for i, output in enumerate(sample_output):
    predicted_class = torch.argmax(output).item()
    actual_class = y_img[i].item()
    confidence = torch.softmax(output, dim=0)[predicted_class].item()
    print(f"  Sample {i+1}: Predicted={predicted_class} (conf: {confidence:.3f}), Actual={actual_class}")

# Visualize feature maps
def visualize_feature_maps(model, input_image, layer_name='conv1'):
    """Visualize feature maps from convolutional layers."""
    model.eval()
    
    # Get activations from first conv layer
    activation = {}
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
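    # Register the hook on the requested layer; keep the handle so it can be removed
    handle = getattr(model, layer_name).register_forward_hook(get_activation(layer_name))
    
    # Forward pass (the hook stores the layer output), then remove the hook
    # so repeated calls do not accumulate hooks on the model
    with torch.no_grad():
        _ = model(input_image.unsqueeze(0))
    handle.remove()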
    
    # Get feature maps
    feature_maps = activation[layer_name][0]  # Remove batch dimension
    
    # Plot first 8 feature maps
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.flatten()
    
    for i in range(min(8, feature_maps.shape[0])):
        axes[i].imshow(feature_maps[i].numpy(), cmap='viridis')
        axes[i].set_title(f'Feature Map {i+1}')
        axes[i].axis('off')
    
    plt.suptitle(f'{layer_name.upper()} Feature Maps')
    plt.tight_layout()
    plt.show()

# Visualize feature maps for first sample
print(f"\n🎨 Visualizing Feature Maps:")
visualize_feature_maps(cnn_model, X_img[0], 'conv1')

print("\n✅ CNN implementation complete!")

## 🔄 3. Simple Recurrent Neural Network (RNN)

**Key Concepts:**
- **Sequential processing** with hidden state memory
- **Temporal dependencies** across time steps
- **Backpropagation Through Time (BPTT)**
- **Vanishing gradient problem** for long sequences

**Architecture:**
```
h₀ → [RNN Cell] → h₁ → [RNN Cell] → h₂ → ... → hₜ → [Linear] → Output
     ↑    ↓              ↑    ↓              ↑    ↓
    x₁   o₁             x₂   o₂             xₜ   oₜ

Where: hₜ = tanh(Wᵢₕxₜ + Wₕₕhₜ₋₁ + bₕ)
```
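
The recurrence above can be checked directly against `nn.RNN`. The following is a minimal sketch with assumed toy sizes: it reads the layer's own weights, unrolls `hₜ = tanh(W_ih·xₜ + b_ih + W_hh·hₜ₋₁ + b_hh)` in a Python loop, and compares the result with the module's output. Note that PyTorch stores two bias vectors (`bias_ih_l0` and `bias_hh_l0`), whose sum plays the role of `bₕ` in the formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration
batch_size, seq_len, n_in, n_hidden = 2, 6, 5, 4
rnn = nn.RNN(n_in, n_hidden, batch_first=True)  # tanh nonlinearity by default
x = torch.randn(batch_size, seq_len, n_in)

W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0

with torch.no_grad():
    # Manual unrolling, starting from a zero hidden state h_0
    h = torch.zeros(batch_size, n_hidden)
    manual_steps = []
    for t in range(seq_len):
        h = torch.tanh(x[:, t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        manual_steps.append(h)
    manual_out = torch.stack(manual_steps, dim=1)  # (batch, seq, hidden)

    # Reference: the built-in layer (h_0 defaults to zeros)
    module_out, h_n = rnn(x)

outputs_match = torch.allclose(manual_out, module_out, atol=1e-5)
last_step_is_h_n = torch.allclose(module_out[:, -1], h_n[0], atol=1e-6)

print(f"Manual unrolling matches nn.RNN: {outputs_match}")
print(f"Last output step equals h_n:     {last_step_is_h_n}")
```

For a single-layer, unidirectional RNN the output at the last time step is the final hidden state, which is why taking `rnn_out[:, -1, :]` is a common choice for sequence classification.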

In [None]:
# Simple Recurrent Neural Network Implementation

class SimpleRNN(nn.Module):
    """
    Simple Recurrent Neural Network for sequence classification.
    
    Architecture: RNN → Linear → Output
    """
    
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(SimpleRNN, self).__init__()
        
        # Store architecture info
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        
        # RNN layer
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # Input shape: (batch, seq, feature)
            nonlinearity='tanh'  # Default activation
        )
        
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, hidden=None):
        """
        Forward pass through RNN.
        
        Args:
            x: Input tensor (batch_size, seq_len, input_size)
            hidden: Initial hidden state (optional)
        
        Returns:
            Output tensor (batch_size, output_size)
            Final hidden state
        """
        # RNN forward pass
        # Output: (batch_size, seq_len, hidden_size)
        # Hidden: (num_layers, batch_size, hidden_size)
        rnn_out, hidden_final = self.rnn(x, hidden)
        
        # Use the last time step output for classification
        # Take output from last time step: (batch_size, hidden_size)
        last_output = rnn_out[:, -1, :]
        
        # Apply linear layer to get class predictions
        output = self.fc(last_output)
        
        return output, hidden_final
    
    def init_hidden(self, batch_size):
        """Initialize hidden state with zeros on the same device as the model."""
        device = self.fc.weight.device
        return torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
    
    def get_info(self):
        """Return model information."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'Architecture': 'Recurrent Neural Network',
            'Input Size': self.input_size,
            'Hidden Size': self.hidden_size,
            'Output Size': self.output_size,
            'Num Layers': self.num_layers,
            'Total Parameters': total_params
        }

# Create and test RNN
print("🔄 Simple RNN Implementation")
print("=" * 40)

# Model parameters
input_size = X_seq.shape[2]      # Number of features per time step
hidden_size = 32
output_size = len(torch.unique(y_seq))  # Number of classes
seq_len = X_seq.shape[1]         # Sequence length

# Create model
rnn_model = SimpleRNN(input_size, hidden_size, output_size, num_layers=1)

# Display model info
model_info = rnn_model.get_info()
for key, value in model_info.items():
    print(f"{key}: {value}")

print(f"\n📊 Model Architecture:")
print(rnn_model)

# Test forward pass
sample_input = X_seq[:5]  # First 5 sequences
batch_size = sample_input.shape[0]

print(f"\n🔍 Forward Pass Test:")
print(f"Input shape: {sample_input.shape}")

# Initialize hidden state
hidden_init = rnn_model.init_hidden(batch_size)
print(f"Initial hidden shape: {hidden_init.shape}")

with torch.no_grad():
    sample_output, final_hidden = rnn_model(sample_input, hidden_init)
    
print(f"Output shape: {sample_output.shape}")
print(f"Final hidden shape: {final_hidden.shape}")

print(f"Sample outputs (raw logits):")
for i, output in enumerate(sample_output):
    predicted_class = torch.argmax(output).item()
    actual_class = y_seq[i].item()
    confidence = torch.softmax(output, dim=0)[predicted_class].item()
    print(f"  Sample {i+1}: Predicted={predicted_class} (conf: {confidence:.3f}), Actual={actual_class}")

# Demonstrate hidden state evolution
def visualize_hidden_states(model, sequence):
    """Visualize how hidden states evolve over time."""
    model.eval()
    
    # Get hidden states at each time step
    hidden_states = []
    hidden = model.init_hidden(1)  # Batch size 1
    
    sequence = sequence.unsqueeze(0)  # Add batch dimension
    
    with torch.no_grad():
        for t in range(sequence.shape[1]):  # For each time step
            # Forward pass for single time step
            input_t = sequence[:, t:t+1, :]  # (1, 1, input_size)
            _, hidden = model.rnn(input_t, hidden)
            hidden_states.append(hidden[0, 0, :].numpy())  # Remove batch and layer dims
    
    hidden_states = np.array(hidden_states)  # (seq_len, hidden_size)
    
    # Plot hidden state evolution
    plt.figure(figsize=(12, 8))
    
    # Plot all hidden dimensions over time
    plt.subplot(2, 1, 1)
    plt.imshow(hidden_states.T, aspect='auto', cmap='viridis')
    plt.colorbar(label='Activation Value')
    plt.title('Hidden State Evolution Over Time')
    plt.xlabel('Time Steps')
    plt.ylabel('Hidden Dimensions')
    
    # Plot hidden state magnitude (L2 norm) over time
    plt.subplot(2, 1, 2)
    hidden_magnitudes = np.linalg.norm(hidden_states, axis=1)
    plt.plot(hidden_magnitudes, 'b-', linewidth=2)
    plt.title('Hidden State Magnitude Over Time')
    plt.xlabel('Time Steps')
    plt.ylabel('||h(t)||')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Visualize hidden states for first sequence
print(f"\n🎨 Visualizing Hidden State Evolution:")
visualize_hidden_states(rnn_model, X_seq[0])

print("\n✅ RNN implementation complete!")

## 🧮 4. Simple Long Short-Term Memory (LSTM)

**Key Concepts:**
- **Memory cells** with controlled information flow
- **Gating mechanisms** (forget, input, output gates)
- **Long-term dependency** handling
- **Gradient flow** through additive cell state

**Architecture:**
```
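    Input Gate     Forget Gate    Output Gate     Candidate
        ↓              ↓              ↓               ↓
    σ(Wᵢ[h,x]+bᵢ)  σ(Wf[h,x]+bf)  σ(Wo[h,x]+bo)  tanh(Wc[h,x]+bc)
        ↓              ↓              ↓               ↓
       i_t            f_t            o_t             C̃_t

             C_t = f_t*C_{t-1} + i_t*C̃_t
             h_t = o_t * tanh(C_t)

Where: [h,x] = [h_{t-1}, x_t] (previous hidden state concatenated with the current input)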
```
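
The gate equations above can be reproduced for a single time step with `nn.LSTMCell`. This is a minimal sketch with assumed toy sizes and random states. PyTorch packs the four weight blocks in the order input, forget, candidate, output, and calls the candidate `g` rather than `C̃`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration
batch_size, n_in, n_hidden = 3, 5, 4
cell = nn.LSTMCell(n_in, n_hidden)

x_t = torch.randn(batch_size, n_in)
h_prev = torch.randn(batch_size, n_hidden)
c_prev = torch.randn(batch_size, n_hidden)

with torch.no_grad():
    # All four pre-activations at once: shape (batch, 4 * hidden)
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)

    i_t = torch.sigmoid(i_t)   # input gate
    f_t = torch.sigmoid(f_t)   # forget gate
    g_t = torch.tanh(g_t)      # candidate cell state
    o_t = torch.sigmoid(o_t)   # output gate

    c_t = f_t * c_prev + i_t * g_t   # additive cell-state update
    h_t = o_t * torch.tanh(c_t)      # gated hidden state

    # Reference: the built-in cell
    h_ref, c_ref = cell(x_t, (h_prev, c_prev))

hidden_matches = torch.allclose(h_t, h_ref, atol=1e-6)
cell_matches = torch.allclose(c_t, c_ref, atol=1e-6)

print(f"Manual h_t matches nn.LSTMCell: {hidden_matches}")
print(f"Manual c_t matches nn.LSTMCell: {cell_matches}")
```

The cell state is updated by an elementwise product and a sum, with no squashing nonlinearity applied to `c_prev` itself. That additive path is what lets gradients flow across many time steps more easily than in a plain RNN.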

In [None]:
# Simple Long Short-Term Memory Implementation

class SimpleLSTM(nn.Module):
    """
    Simple LSTM Network for sequence classification.
    
    Architecture: LSTM → Linear → Output
    """
    
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(SimpleLSTM, self).__init__()
        
        # Store architecture info
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # Input shape: (batch, seq, feature)
            dropout=0.2 if num_layers > 1 else 0
        )
        
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, hidden=None):
        """
        Forward pass through LSTM.
        
        Args:
            x: Input tensor (batch_size, seq_len, input_size)
            hidden: Tuple of (h_0, c_0) initial states (optional)
        
        Returns:
            Output tensor (batch_size, output_size)
            Final hidden and cell states
        """
        # LSTM forward pass
        # Output: (batch_size, seq_len, hidden_size)
        # Hidden: (h_n, c_n) where each is (num_layers, batch_size, hidden_size)
        lstm_out, hidden_final = self.lstm(x, hidden)
        
        # Use the last time step output for classification
        # Take output from last time step: (batch_size, hidden_size)
        last_output = lstm_out[:, -1, :]
        
        # Apply linear layer to get class predictions
        output = self.fc(last_output)
        
        return output, hidden_final
    
    def init_hidden(self, batch_size):
        """Initialize hidden and cell states with zeros on the same device as the model."""
        shape = (self.num_layers, batch_size, self.hidden_size)
        device = self.fc.weight.device
        h_0 = torch.zeros(shape, device=device)
        c_0 = torch.zeros(shape, device=device)
        return (h_0, c_0)
    
    def get_info(self):
        """Return model information."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'Architecture': 'Long Short-Term Memory',
            'Input Size': self.input_size,
            'Hidden Size': self.hidden_size,
            'Output Size': self.output_size,
            'Num Layers': self.num_layers,
            'Total Parameters': total_params
        }

# Create and test LSTM
print("🧮 Simple LSTM Implementation")
print("=" * 40)

# Model parameters (same as RNN for comparison)
input_size = X_seq.shape[2]
hidden_size = 32
output_size = len(torch.unique(y_seq))

# Create model
lstm_model = SimpleLSTM(input_size, hidden_size, output_size, num_layers=1)

# Display model info
model_info = lstm_model.get_info()
for key, value in model_info.items():
    print(f"{key}: {value}")

print(f"\n📊 Model Architecture:")
print(lstm_model)

# Test forward pass
sample_input = X_seq[:5]
batch_size = sample_input.shape[0]

print(f"\n🔍 Forward Pass Test:")
print(f"Input shape: {sample_input.shape}")

# Initialize hidden and cell states
hidden_init = lstm_model.init_hidden(batch_size)
print(f"Initial hidden shape: {hidden_init[0].shape}")
print(f"Initial cell shape: {hidden_init[1].shape}")

with torch.no_grad():
    sample_output, final_states = lstm_model(sample_input, hidden_init)
    
print(f"Output shape: {sample_output.shape}")
print(f"Final hidden shape: {final_states[0].shape}")
print(f"Final cell shape: {final_states[1].shape}")

print(f"Sample outputs (raw logits):")
for i, output in enumerate(sample_output):
    predicted_class = torch.argmax(output).item()
    actual_class = y_seq[i].item()
    confidence = torch.softmax(output, dim=0)[predicted_class].item()
    print(f"  Sample {i+1}: Predicted={predicted_class} (conf: {confidence:.3f}), Actual={actual_class}")

# Demonstrate LSTM states evolution
def visualize_lstm_states(model, sequence):
    """Visualize LSTM hidden and cell states evolution."""
    model.eval()
    
    hidden_states = []
    cell_states = []
    
    # Initialize states
    h_t, c_t = model.init_hidden(1)
    sequence = sequence.unsqueeze(0)  # Add batch dimension
    
    with torch.no_grad():
        for t in range(sequence.shape[1]):
            # Forward pass for single time step
            input_t = sequence[:, t:t+1, :]
            _, (h_t, c_t) = model.lstm(input_t, (h_t, c_t))
            
            hidden_states.append(h_t[0, 0, :].numpy())
            cell_states.append(c_t[0, 0, :].numpy())
    
    hidden_states = np.array(hidden_states)
    cell_states = np.array(cell_states)
    
    # Plot states evolution
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Hidden states heatmap
    im1 = axes[0,0].imshow(hidden_states.T, aspect='auto', cmap='viridis')
    axes[0,0].set_title('Hidden States Evolution')
    axes[0,0].set_xlabel('Time Steps')
    axes[0,0].set_ylabel('Hidden Dimensions')
    plt.colorbar(im1, ax=axes[0,0])
    
    # Cell states heatmap
    im2 = axes[0,1].imshow(cell_states.T, aspect='auto', cmap='plasma')
    axes[0,1].set_title('Cell States Evolution')
    axes[0,1].set_xlabel('Time Steps')
    axes[0,1].set_ylabel('Cell Dimensions')
    plt.colorbar(im2, ax=axes[0,1])
    
    # States magnitude over time
    hidden_magnitudes = np.linalg.norm(hidden_states, axis=1)
    cell_magnitudes = np.linalg.norm(cell_states, axis=1)
    
    axes[1,0].plot(hidden_magnitudes, 'b-', linewidth=2, label='Hidden State')
    axes[1,0].plot(cell_magnitudes, 'r-', linewidth=2, label='Cell State')
    axes[1,0].set_title('State Magnitudes Over Time')
    axes[1,0].set_xlabel('Time Steps')
    axes[1,0].set_ylabel('||State||')
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)
    
    # States comparison
    axes[1,1].plot(hidden_states[:, 0], 'b-', alpha=0.7, label='Hidden[0]')
    axes[1,1].plot(cell_states[:, 0], 'r-', alpha=0.7, label='Cell[0]')
    axes[1,1].plot(hidden_states[:, 1], 'b--', alpha=0.7, label='Hidden[1]')
    axes[1,1].plot(cell_states[:, 1], 'r--', alpha=0.7, label='Cell[1]')
    axes[1,1].set_title('Sample Dimensions Evolution')
    axes[1,1].set_xlabel('Time Steps')
    axes[1,1].set_ylabel('State Value')
    axes[1,1].legend()
    axes[1,1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Visualize LSTM states for first sequence
print(f"\n🎨 Visualizing LSTM States Evolution:")
visualize_lstm_states(lstm_model, X_seq[0])

print("\n✅ LSTM implementation complete!")

## ⚡ 5. Simple Gated Recurrent Unit (GRU)

**Key Concepts:**
- **Simplified gating** compared to LSTM (2 gates vs 3)
- **Reset and update gates** control information flow
- **No separate cell state** (combined with hidden state)
- **Fewer parameters than LSTM** (3 weight blocks per layer instead of 4), often with comparable accuracy

**Architecture:**
```
    Reset Gate         Update Gate
        ↓                  ↓
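    σ(Wr[h,x]+br)     σ(Wz[h,x]+bz)
        ↓                  ↓
       r_t               z_t
        ↓                  ↓
   h̃_t = tanh(Wh[r_t*h_{t-1}, x_t] + bh)
                    ↓
         h_t = (1-z_t)*h_{t-1} + z_t*h̃_t

Note: nn.GRU uses the mirrored convention h_t = (1-z_t)*h̃_t + z_t*h_{t-1};
the two forms are equivalent up to relabeling z_t as 1-z_t.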
```
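
The efficiency argument can be made concrete by counting parameters. The sketch below assumes PyTorch's layout for a single-layer, unidirectional recurrent module: each weight block holds an input matrix, a hidden matrix, and two bias vectors. A plain RNN has one block, a GRU has three (reset gate, update gate, candidate), and an LSTM has four (input, forget, and output gates plus the candidate). The helper name is hypothetical; the sizes 5 and 32 are the input and hidden sizes used for the sequence models in this notebook.

```python
def recurrent_param_count(input_size, hidden_size, n_blocks):
    """Parameters of one recurrent layer, excluding any output head."""
    per_block = (hidden_size * input_size      # input-to-hidden weights
                 + hidden_size * hidden_size   # hidden-to-hidden weights
                 + 2 * hidden_size)            # two bias vectors
    return n_blocks * per_block


blocks_per_layer = {"RNN": 1, "GRU": 3, "LSTM": 4}
counts = {
    name: recurrent_param_count(input_size=5, hidden_size=32, n_blocks=n)
    for name, n in blocks_per_layer.items()
}

for name, count in counts.items():
    ratio = count / counts["RNN"]
    print(f"{name:>4}: {count:>5,} parameters ({ratio:.0f}x RNN)")
```

So although the GRU has 2 gates against the LSTM's 3, the parameter ratio is 3:4 because both also learn a candidate-state transform. These counts cover only the recurrent layer; a final `nn.Linear` classification head adds the same number of parameters to each model.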

In [None]:
# Simple Gated Recurrent Unit Implementation

class SimpleGRU(nn.Module):
    """
    Simple GRU Network for sequence classification.
    
    Architecture: GRU → Linear → Output
    """
    
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(SimpleGRU, self).__init__()
        
        # Store architecture info
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        
        # GRU layer
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,  # Input shape: (batch, seq, feature)
            dropout=0.2 if num_layers > 1 else 0
        )
        
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, hidden=None):
        """
        Forward pass through GRU.
        
        Args:
            x: Input tensor (batch_size, seq_len, input_size)
            hidden: Initial hidden state (optional)
        
        Returns:
            Output tensor (batch_size, output_size)
            Final hidden state
        """
        # GRU forward pass
        # Output: (batch_size, seq_len, hidden_size)
        # Hidden: (num_layers, batch_size, hidden_size)
        gru_out, hidden_final = self.gru(x, hidden)
        
        # Use the last time step output for classification
        # Take output from last time step: (batch_size, hidden_size)
        last_output = gru_out[:, -1, :]
        
        # Apply linear layer to get class predictions
        output = self.fc(last_output)
        
        return output, hidden_final
    
    def init_hidden(self, batch_size):
        """Initialize hidden state with zeros on the same device as the model."""
        device = self.fc.weight.device
        return torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
    
    def get_info(self):
        """Return model information."""
        total_params = sum(p.numel() for p in self.parameters())
        return {
            'Architecture': 'Gated Recurrent Unit',
            'Input Size': self.input_size,
            'Hidden Size': self.hidden_size,
            'Output Size': self.output_size,
            'Num Layers': self.num_layers,
            'Total Parameters': total_params
        }

# Create and test GRU
print("⚡ Simple GRU Implementation")
print("=" * 40)

# Model parameters (same as RNN/LSTM for comparison)
input_size = X_seq.shape[2]
hidden_size = 32
output_size = len(torch.unique(y_seq))

# Create model
gru_model = SimpleGRU(input_size, hidden_size, output_size, num_layers=1)

# Display model info
model_info = gru_model.get_info()
for key, value in model_info.items():
    print(f"{key}: {value}")

print(f"\n📊 Model Architecture:")
print(gru_model)

# Test forward pass
sample_input = X_seq[:5]
batch_size = sample_input.shape[0]

print(f"\n🔍 Forward Pass Test:")
print(f"Input shape: {sample_input.shape}")

# Initialize hidden state
hidden_init = gru_model.init_hidden(batch_size)
print(f"Initial hidden shape: {hidden_init.shape}")

with torch.no_grad():
    sample_output, final_hidden = gru_model(sample_input, hidden_init)
    
print(f"Output shape: {sample_output.shape}")
print(f"Final hidden shape: {final_hidden.shape}")

print("Sample predictions (argmax class with softmax confidence):")
for i, output in enumerate(sample_output):
    predicted_class = torch.argmax(output).item()
    actual_class = y_seq[i].item()
    confidence = torch.softmax(output, dim=0)[predicted_class].item()
    print(f"  Sample {i+1}: Predicted={predicted_class} (conf: {confidence:.3f}), Actual={actual_class}")

# Compare model complexities
print(f"\n📊 Model Comparison:")
print("=" * 50)

models_comparison = {
    'RNN': rnn_model,
    'LSTM': lstm_model,
    'GRU': gru_model
}

for name, model in models_comparison.items():
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{name:>6}: {total_params:>6,} parameters")

# Side-by-side summary of the RNN, LSTM, and GRU update equations (conceptual)
def compare_architectures():
    """Visual comparison of RNN architectures."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # RNN
    axes[0].text(0.5, 0.7, 'RNN', ha='center', va='center', fontsize=16, fontweight='bold')
    axes[0].text(0.5, 0.5, 'h_t = tanh(W_h[h_{t-1}, x_t] + b_h)', ha='center', va='center', fontsize=10)
    axes[0].text(0.5, 0.3, '• Simple recurrence\n• Vanishing gradients\n• Fast computation', ha='center', va='center', fontsize=9)
    axes[0].set_title('Vanilla RNN')
    axes[0].set_xlim(0, 1)
    axes[0].set_ylim(0, 1)
    axes[0].axis('off')
    
    # LSTM
    axes[1].text(0.5, 0.8, 'LSTM', ha='center', va='center', fontsize=16, fontweight='bold')
    axes[1].text(0.5, 0.65, '• Forget Gate: f_t = σ(W_f[h,x] + b_f)', ha='center', va='center', fontsize=8)
    axes[1].text(0.5, 0.55, '• Input Gate: i_t = σ(W_i[h,x] + b_i)', ha='center', va='center', fontsize=8)
    axes[1].text(0.5, 0.45, '• Output Gate: o_t = σ(W_o[h,x] + b_o)', ha='center', va='center', fontsize=8)
    axes[1].text(0.5, 0.35, '• Cell State: C_t = f_t*C_{t-1} + i_t*C̃_t', ha='center', va='center', fontsize=8)
    axes[1].text(0.5, 0.2, '• Long-term memory\n• 3 gates + cell state\n• More parameters', ha='center', va='center', fontsize=9)
    axes[1].set_title('LSTM')
    axes[1].set_xlim(0, 1)
    axes[1].set_ylim(0, 1)
    axes[1].axis('off')
    
    # GRU
    axes[2].text(0.5, 0.8, 'GRU', ha='center', va='center', fontsize=16, fontweight='bold')
    axes[2].text(0.5, 0.65, '• Reset Gate: r_t = σ(W_r[h,x] + b_r)', ha='center', va='center', fontsize=8)
    axes[2].text(0.5, 0.55, '• Update Gate: z_t = σ(W_z[h,x] + b_z)', ha='center', va='center', fontsize=8)
    axes[2].text(0.5, 0.45, '• New State: h̃_t = tanh(W_h[r_t*h,x] + b_h)', ha='center', va='center', fontsize=8)
    axes[2].text(0.5, 0.35, '• Final: h_t = (1-z_t)*h_{t-1} + z_t*h̃_t', ha='center', va='center', fontsize=8)
    axes[2].text(0.5, 0.2, '• Efficient gating\n• 2 gates only\n• Fewer parameters', ha='center', va='center', fontsize=9)
    axes[2].set_title('GRU')
    axes[2].set_xlim(0, 1)
    axes[2].set_ylim(0, 1)
    axes[2].axis('off')
    
    plt.suptitle('RNN Architecture Comparison', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

print(f"\n🎨 Architecture Comparison:")
compare_architectures()

print("\n✅ GRU implementation complete!")

## 🏋️ 6. Training & Evaluation Functions

**Key Training Concepts:**
- **Loss functions** (CrossEntropy, MSE)
- **Optimizers** (SGD, Adam)
- **Learning rate scheduling**
- **Gradient clipping** for RNNs
- **Validation and metrics**
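
Two of the concepts above, learning rate scheduling and gradient clipping, can be sketched in a few lines. This is a minimal illustration under stated assumptions, not part of the trainer defined in this notebook: the toy GRU, the random data, and the `StepLR` settings (`step_size=2`, `gamma=0.5`) are all hypothetical values chosen for the demo.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy model and data (hypothetical, for illustration only)
toy_model = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
toy_x = torch.randn(16, 5, 4)        # (batch, seq, feature)
toy_target = torch.randn(16, 5, 8)   # one regression target per time step

optimizer = optim.Adam(toy_model.parameters(), lr=0.01)
# StepLR multiplies the learning rate by gamma every step_size epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

max_norm = 1.0
lr_history = []
clipped_norms = []

for epoch in range(4):
    optimizer.zero_grad()
    out, _ = toy_model(toy_x)
    loss = nn.functional.mse_loss(out, toy_target)
    loss.backward()

    # Rescale gradients in place so their global L2 norm is at most max_norm
    torch.nn.utils.clip_grad_norm_(toy_model.parameters(), max_norm)
    norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in toy_model.parameters()))
    clipped_norms.append(norm_after.item())

    optimizer.step()
    lr_history.append(optimizer.param_groups[0]['lr'])
    scheduler.step()                 # once per epoch, after optimizer.step()

print("Learning rate per epoch:", lr_history)
print("Gradient norm after clipping:", [round(n, 4) for n in clipped_norms])
```

Clipping goes between `loss.backward()` and `optimizer.step()`, and the scheduler is stepped once per epoch after the optimizer.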

In [None]:
# Training and Evaluation Functions

class ModelTrainer:
    """Generic trainer for all neural network models."""
    
    def __init__(self, model, device='cpu'):
        self.model = model.to(device)
        self.device = device
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []
    
    def train_epoch(self, train_loader, criterion, optimizer, clip_grad=None):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)
            
            # Forward pass
            optimizer.zero_grad()
            
            # Handle different model types
            if hasattr(self.model, 'init_hidden'):  # RNN-based models
                batch_size = data.size(0)
                if isinstance(self.model, SimpleLSTM):
                    hidden = self.model.init_hidden(batch_size)
                    hidden = (hidden[0].to(self.device), hidden[1].to(self.device))
                else:
                    hidden = self.model.init_hidden(batch_size).to(self.device)
                output, _ = self.model(data, hidden)
            else:  # ANN/CNN models
                output = self.model(data)
            
            # Compute loss
            loss = criterion(output, target)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping for RNN models
            if clip_grad:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad)
            
            optimizer.step()
            
            # Statistics
            total_loss += loss.item()
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
        
        avg_loss = total_loss / len(train_loader)
        accuracy = 100. * correct / total
        
        self.train_losses.append(avg_loss)
        self.train_accuracies.append(accuracy)
        
        return avg_loss, accuracy
    
    def validate(self, val_loader, criterion):
        """Validate the model."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)
                
                # Handle different model types
                if hasattr(self.model, 'init_hidden'):  # RNN-based models
                    batch_size = data.size(0)
                    if isinstance(self.model, SimpleLSTM):
                        hidden = self.model.init_hidden(batch_size)
                        hidden = (hidden[0].to(self.device), hidden[1].to(self.device))
                    else:
                        hidden = self.model.init_hidden(batch_size).to(self.device)
                    output, _ = self.model(data, hidden)
                else:  # ANN/CNN models
                    output = self.model(data)
                
                # Compute loss
                loss = criterion(output, target)
                total_loss += loss.item()
                
                # Statistics
                predicted = output.argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        
        avg_loss = total_loss / len(val_loader)
        accuracy = 100. * correct / total
        
        self.val_losses.append(avg_loss)
        self.val_accuracies.append(accuracy)
        
        return avg_loss, accuracy
    
    def train(self, train_loader, val_loader, num_epochs=10, lr=0.001, clip_grad=None):
        """Complete training loop."""
        # Define loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        print(f"🚀 Training {self.model.__class__.__name__} for {num_epochs} epochs...")
        print(f"📊 Training samples: {len(train_loader.dataset)}")
        print(f"📊 Validation samples: {len(val_loader.dataset)}")
        print("-" * 60)
        
        best_val_acc = 0.0
        start_time = time.time()
        
        for epoch in range(num_epochs):
            # Training
            train_loss, train_acc = self.train_epoch(train_loader, criterion, optimizer, clip_grad)
            
            # Validation
            val_loss, val_acc = self.validate(val_loader, criterion)
            
            # Print progress
            print(f"Epoch {epoch+1:2d}/{num_epochs} | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
            
            # Track the best validation accuracy (uncomment torch.save below to also checkpoint the weights)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # torch.save(self.model.state_dict(), 'best_model.pth')
        
        training_time = time.time() - start_time
        print(f"\n✅ Training completed in {training_time:.2f} seconds")
        print(f"🏆 Best validation accuracy: {best_val_acc:.2f}%")
        
        return {
            'best_val_acc': best_val_acc,
            'final_train_acc': self.train_accuracies[-1],
            'final_val_acc': self.val_accuracies[-1],
            'training_time': training_time
        }
    
    def plot_training_history(self):
        """Plot training and validation metrics."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        
        epochs = range(1, len(self.train_losses) + 1)
        
        # Loss plot
        ax1.plot(epochs, self.train_losses, 'b-', label='Training Loss', linewidth=2)
        ax1.plot(epochs, self.val_losses, 'r-', label='Validation Loss', linewidth=2)
        ax1.set_title('Training and Validation Loss')
        ax1.set_xlabel('Epochs')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # Accuracy plot
        ax2.plot(epochs, self.train_accuracies, 'b-', label='Training Accuracy', linewidth=2)
        ax2.plot(epochs, self.val_accuracies, 'r-', label='Validation Accuracy', linewidth=2)
        ax2.set_title('Training and Validation Accuracy')
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Accuracy (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

def create_data_loaders(X, y, batch_size=32, train_split=0.8):
    """Create train and validation data loaders."""
    # Create dataset
    dataset = TensorDataset(X, y)
    
    # Split dataset
    train_size = int(train_split * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader

print("🏋️ Training and Evaluation Functions Ready!")
print("=" * 50)

# Demonstrate training setup
sample_loader = DataLoader(TensorDataset(X_clf[:100], y_clf[:100]), batch_size=16)
sample_trainer = ModelTrainer(ann_model, device)

print(f"✅ Trainer created for {ann_model.__class__.__name__}")
print(f"📊 Device: {device}")
print(f"🔧 Sample batch size: 16")
print(f"💾 Ready for training all models!")

## 🚀 7. Test All Models

Now let's train and compare all our neural network implementations!

In [None]:
# Test All Neural Network Models

# Training parameters
EPOCHS = 5  # Reduced for demo
BATCH_SIZE = 32
LEARNING_RATE = 0.001

results = {}

print("🚀 Training All Neural Network Models")
print("=" * 60)

# 1. Test ANN on Classification Data
print("\n🧠 1. Testing ANN on Classification Data")
print("-" * 40)

# Create fresh ANN model
ann_model = SimpleANN(X_clf.shape[1], 64, len(torch.unique(y_clf)))
ann_train_loader, ann_val_loader = create_data_loaders(X_clf, y_clf, BATCH_SIZE)

# Train ANN
ann_trainer = ModelTrainer(ann_model, device)
ann_results = ann_trainer.train(ann_train_loader, ann_val_loader, EPOCHS, LEARNING_RATE)
results['ANN'] = ann_results

# Plot training history
ann_trainer.plot_training_history()

# 2. Test CNN on Image Data
print("\n🖼️ 2. Testing CNN on Image Data")
print("-" * 40)

# Create fresh CNN model
cnn_model = SimpleCNN(X_img.shape[1], len(torch.unique(y_img)))
cnn_train_loader, cnn_val_loader = create_data_loaders(X_img, y_img, BATCH_SIZE)

# Train CNN
cnn_trainer = ModelTrainer(cnn_model, device)
cnn_results = cnn_trainer.train(cnn_train_loader, cnn_val_loader, EPOCHS, LEARNING_RATE)
results['CNN'] = cnn_results

# Plot training history
cnn_trainer.plot_training_history()

# 3. Test RNN on Sequence Data
print("\n🔄 3. Testing RNN on Sequence Data")
print("-" * 40)

# Create fresh RNN model
rnn_model = SimpleRNN(X_seq.shape[2], 32, len(torch.unique(y_seq)))
rnn_train_loader, rnn_val_loader = create_data_loaders(X_seq, y_seq, BATCH_SIZE)

# Train RNN (with gradient clipping)
rnn_trainer = ModelTrainer(rnn_model, device)
rnn_results = rnn_trainer.train(rnn_train_loader, rnn_val_loader, EPOCHS, LEARNING_RATE, clip_grad=1.0)
results['RNN'] = rnn_results

# Plot training history
rnn_trainer.plot_training_history()

# 4. Test LSTM on Sequence Data
print("\n🧮 4. Testing LSTM on Sequence Data")
print("-" * 40)

# Create fresh LSTM model
lstm_model = SimpleLSTM(X_seq.shape[2], 32, len(torch.unique(y_seq)))
lstm_train_loader, lstm_val_loader = create_data_loaders(X_seq, y_seq, BATCH_SIZE)

# Train LSTM
lstm_trainer = ModelTrainer(lstm_model, device)
lstm_results = lstm_trainer.train(lstm_train_loader, lstm_val_loader, EPOCHS, LEARNING_RATE, clip_grad=1.0)
results['LSTM'] = lstm_results

# Plot training history
lstm_trainer.plot_training_history()

# 5. Test GRU on Sequence Data
print("\n⚡ 5. Testing GRU on Sequence Data")
print("-" * 40)

# Create fresh GRU model
gru_model = SimpleGRU(X_seq.shape[2], 32, len(torch.unique(y_seq)))
gru_train_loader, gru_val_loader = create_data_loaders(X_seq, y_seq, BATCH_SIZE)

# Train GRU
gru_trainer = ModelTrainer(gru_model, device)
gru_results = gru_trainer.train(gru_train_loader, gru_val_loader, EPOCHS, LEARNING_RATE, clip_grad=1.0)
results['GRU'] = gru_results

# Plot training history
gru_trainer.plot_training_history()

# Final Results Comparison
print("\n📊 FINAL RESULTS COMPARISON")
print("=" * 60)

# Create comparison table
import pandas as pd

comparison_data = []
for model_name, result in results.items():
    comparison_data.append({
        'Model': model_name,
        'Best Val Acc (%)': f"{result['best_val_acc']:.2f}",
        'Final Train Acc (%)': f"{result['final_train_acc']:.2f}",
        'Final Val Acc (%)': f"{result['final_val_acc']:.2f}",
        'Training Time (s)': f"{result['training_time']:.2f}"
    })

df = pd.DataFrame(comparison_data)
print(df.to_string(index=False))

# Model parameter comparison
print(f"\n🔧 MODEL COMPLEXITY COMPARISON")
print("-" * 40)

models_for_comparison = [
    ('ANN', ann_model),
    ('CNN', cnn_model), 
    ('RNN', rnn_model),
    ('LSTM', lstm_model),
    ('GRU', gru_model)
]

for name, model in models_for_comparison:
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name:>6}: {total_params:>8,} total params ({trainable_params:>8,} trainable)")

# Performance visualization
def plot_final_comparison():
    """Plot final performance comparison."""
    model_names = list(results.keys())
    val_accuracies = [results[name]['best_val_acc'] for name in model_names]
    training_times = [results[name]['training_time'] for name in model_names]
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    
    # Accuracy comparison
    bars1 = ax1.bar(model_names, val_accuracies, alpha=0.8, color=['blue', 'green', 'red', 'orange', 'purple'])
    ax1.set_title('Best Validation Accuracy Comparison')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_ylim(0, 100)
    
    # Add value labels on bars
    for bar, acc in zip(bars1, val_accuracies):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 1,
                f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    # Training time comparison
    bars2 = ax2.bar(model_names, training_times, alpha=0.8, color=['blue', 'green', 'red', 'orange', 'purple'])
    ax2.set_title('Training Time Comparison')
    ax2.set_ylabel('Time (seconds)')
    
    # Add value labels on bars
    for bar, time_val in zip(bars2, training_times):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.1,
                f'{time_val:.1f}s', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

print(f"\n🎨 Performance Visualization:")
plot_final_comparison()

# Key takeaways
print(f"\n📚 KEY TAKEAWAYS FOR EXAM")
print("=" * 50)
print("🧠 ANN: Simple feedforward, good for tabular data")
print("🖼️ CNN: Excellent for images, spatial feature extraction")
print("🔄 RNN: Sequential data, but prone to vanishing gradients") 
print("🧮 LSTM: Well suited to long sequences, handles long-term dependencies")
print("⚡ GRU: Efficient alternative to LSTM, fewer parameters")
print("\n🎯 Model Selection Guidelines:")
print("• Tabular/structured data → ANN")
print("• Images/spatial data → CNN") 
print("• Short sequences → RNN")
print("• Long sequences → LSTM/GRU")
print("• Resource constrained → GRU over LSTM")

print("\n✅ ALL NEURAL NETWORK IMPLEMENTATIONS COMPLETE!")
print("🎓 Ready for your exam! 🚀")

## 📚 Exam Cheat Sheet - Neural Network Quick Reference

### 🧠 **ANN (Artificial Neural Network)**
```python
# Key Components:
# - Linear layers: nn.Linear(input_size, output_size)
# - Activation: ReLU, Sigmoid, Tanh
# - Loss: CrossEntropyLoss, MSELoss
# - Optimizer: Adam, SGD

# When to use: Tabular data, classification/regression
# Pros: Simple, fast training
# Cons: No spatial/temporal awareness
```
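
One detail worth remembering about the loss listed above: `nn.CrossEntropyLoss` expects raw logits, because it applies log-softmax internally. The sketch below checks this on made-up logits and targets (both are assumptions for illustration) by comparing it with the explicit two-step computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 3)             # raw scores: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])   # class indices (int64)

# CrossEntropyLoss applies log-softmax internally, so pass raw logits
ce = nn.CrossEntropyLoss()(logits, targets)

# Equivalent two-step computation: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print("Losses match:", torch.allclose(ce, manual))
```

Adding a softmax layer before `CrossEntropyLoss` therefore applies softmax twice and tends to slow learning.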

### 🖼️ **CNN (Convolutional Neural Network)**
```python
# Key Components:
# - Conv2d: nn.Conv2d(in_channels, out_channels, kernel_size)
# - Pooling: nn.MaxPool2d(kernel_size)
# - Flatten: x.view(x.size(0), -1)

# When to use: Images, spatial data
# Pros: Translation equivariance (approximate invariance with pooling), parameter sharing
# Cons: Limited to grid-like data
```
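
A common exam task is working out the tensor size that reaches the flatten step. The helper below is a small pure-Python sketch of the standard output-size formula (ignoring dilation); the function name `conv_output_size`, the 28x28 input, and the 16 output channels are hypothetical choices for the example.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial dimension for a conv or pooling layer."""
    return (size + 2 * padding - kernel_size) // stride + 1

# Hypothetical 28x28 input passed through Conv2d(k=3, padding=1) then MaxPool2d(2)
after_conv = conv_output_size(28, kernel_size=3, padding=1)
# nn.MaxPool2d uses stride = kernel_size when no stride is given
after_pool = conv_output_size(after_conv, kernel_size=2, stride=2)

out_channels = 16  # assumed number of filters
flattened_features = out_channels * after_pool * after_pool

print(f"After conv: {after_conv}x{after_conv}")
print(f"After pool: {after_pool}x{after_pool}")
print(f"Features into the first Linear layer: {flattened_features}")
```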

### 🔄 **RNN (Recurrent Neural Network)**
```python
# Key Components:
# - Hidden state: h_t = tanh(W_ih * x_t + W_hh * h_{t-1} + b)
# - Sequential processing
# - Vanishing gradient problem

# When to use: Short sequences, simple temporal patterns
# Pros: Memory of previous inputs
# Cons: Vanishing gradients, time steps cannot be processed in parallel
```
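
The hidden-state update above can be checked against PyTorch directly. The sketch below recomputes one step of `nn.RNNCell` by hand; the sizes (3 inputs, 4 hidden units, batch of 2) are arbitrary assumptions. PyTorch keeps two bias vectors, `bias_ih` and `bias_hh`, whose sum plays the role of the single `b` in the formula.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.RNNCell(input_size=3, hidden_size=4)  # tanh nonlinearity by default
x_t = torch.randn(2, 3)      # (batch, input_size)
h_prev = torch.zeros(2, 4)   # (batch, hidden_size)

with torch.no_grad():
    # h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
    manual = torch.tanh(
        x_t @ cell.weight_ih.T + cell.bias_ih
        + h_prev @ cell.weight_hh.T + cell.bias_hh
    )
    builtin = cell(x_t, h_prev)

print("Manual step matches nn.RNNCell:", torch.allclose(manual, builtin, atol=1e-6))
```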

### 🧮 **LSTM (Long Short-Term Memory)**
```python
# Key Components:
# - Forget gate:  f_t = σ(W_f * [h_{t-1}, x_t] + b_f)
# - Input gate:   i_t = σ(W_i * [h_{t-1}, x_t] + b_i)
# - Output gate:  o_t = σ(W_o * [h_{t-1}, x_t] + b_o)
# - Cell state:   C_t = f_t * C_{t-1} + i_t * tanh(W_C * [h_{t-1}, x_t] + b_C)
# - Hidden state: h_t = o_t * tanh(C_t)

# When to use: Long sequences, language modeling
# Pros: Handles long-term dependencies, mitigates (but does not eliminate) vanishing gradients
# Cons: More parameters, slower than GRU
```
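
To see how the gate equations combine into one step, the sketch below recomputes `nn.LSTMCell` by hand. The sizes are arbitrary assumptions. PyTorch stacks the four weight blocks in the order input, forget, candidate, output, which is why the chunk order differs from the list above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.LSTMCell(input_size=3, hidden_size=4)
x_t = torch.randn(2, 3)      # (batch, input_size)
h_prev = torch.randn(2, 4)   # previous hidden state
c_prev = torch.randn(2, 4)   # previous cell state

with torch.no_grad():
    # All four pre-activations in one matrix product
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)   # input, forget, candidate, output

    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)

    c_t = f * c_prev + i * g             # keep part of the old memory, add new
    h_t = o * torch.tanh(c_t)            # expose a gated view of the cell state

    h_ref, c_ref = cell(x_t, (h_prev, c_prev))

print("Hidden state matches:", torch.allclose(h_t, h_ref, atol=1e-6))
print("Cell state matches:", torch.allclose(c_t, c_ref, atol=1e-6))
```

The additive update of `c_t` is what lets gradients flow across many steps when the forget gate stays close to 1.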

### ⚡ **GRU (Gated Recurrent Unit)**
```python
# Key Components:
# - Reset gate:   r_t = σ(W_r * [h_{t-1}, x_t] + b_r)
# - Update gate:  z_t = σ(W_z * [h_{t-1}, x_t] + b_z)
# - Candidate:    h̃_t = tanh(W_h * [r_t * h_{t-1}, x_t] + b_h)
# - Hidden state: h_t = (1 - z_t) * h_{t-1} + z_t * h̃_t
# Note: nn.GRU swaps the roles of z_t and (1 - z_t) and applies r_t after the
# hidden-state matrix product; the gating idea is the same.

# When to use: Long sequences, resource constraints
# Pros: Fewer parameters than LSTM, often faster to train
# Cons: May not capture very long dependencies as well as LSTM
```
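
Textbooks and libraries differ slightly in how they write the GRU update, so it helps to check one step against PyTorch. The sketch below recomputes `nn.GRUCell` by hand using PyTorch's convention, in which `z` weights the previous state and the reset gate multiplies the hidden-side product. The sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

cell = nn.GRUCell(input_size=3, hidden_size=4)
x_t = torch.randn(2, 3)      # (batch, input_size)
h_prev = torch.randn(2, 4)   # previous hidden state

with torch.no_grad():
    # Input-side and hidden-side pre-activations, stacked as (reset, update, new)
    gi = x_t @ cell.weight_ih.T + cell.bias_ih
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)

    r = torch.sigmoid(i_r + h_r)         # reset gate
    z = torch.sigmoid(i_z + h_z)         # update gate
    n = torch.tanh(i_n + r * h_n)        # candidate state
    h_t = (1 - z) * n + z * h_prev       # PyTorch: z keeps the old state

    h_ref = cell(x_t, h_prev)

print("Manual step matches nn.GRUCell:", torch.allclose(h_t, h_ref, atol=1e-6))
```

Either convention is fine in an exam answer as long as the gate roles are stated consistently.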

### 🎯 **Quick Decision Tree for Model Selection**
```
Data Type?
├── Tabular/Structured → ANN
├── Images/Spatial → CNN
└── Sequential
    ├── Short sequences → RNN
    ├── Long sequences + accuracy critical → LSTM
    └── Long sequences + speed critical → GRU
```
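
The LSTM-versus-GRU branch is partly a question of model size. For a single-layer, unidirectional recurrent layer in PyTorch's layout (two bias vectors per weight block), the parameter count has a closed form. The helper name `recurrent_param_count` and the input size of 10 are hypothetical; the hidden size of 32 matches the models in this notebook.

```python
def recurrent_param_count(input_size, hidden_size, num_blocks):
    """Parameters in one recurrent layer, assuming PyTorch's two-bias layout.

    num_blocks: 1 for a vanilla RNN, 3 for GRU (2 gates + candidate),
                4 for LSTM (3 gates + candidate).
    """
    weights = hidden_size * (input_size + hidden_size)  # W_ih and W_hh
    biases = 2 * hidden_size                            # b_ih and b_hh
    return num_blocks * (weights + biases)

input_size, hidden_size = 10, 32  # input_size is an assumed value

rnn_params = recurrent_param_count(input_size, hidden_size, num_blocks=1)
gru_params = recurrent_param_count(input_size, hidden_size, num_blocks=3)
lstm_params = recurrent_param_count(input_size, hidden_size, num_blocks=4)

print(f"RNN:  {rnn_params:,}")
print(f"GRU:  {gru_params:,}")
print(f"LSTM: {lstm_params:,}")
```

These counts cover only the recurrent layer. The final `nn.Linear` adds the same number of parameters to each model.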

### 📊 **Common Hyperparameters**
- **Learning Rate**: 0.001 (Adam), 0.01 (SGD)
- **Batch Size**: 32, 64, 128
- **Hidden Layers**: 1-3 for simple tasks
- **Hidden Units**: 64, 128, 256
- **Dropout**: 0.2-0.5 for regularization
- **Epochs**: 10-100 depending on data size

### 🛠️ **PyTorch Essentials**
```python
# Basic training loop
for epoch in range(epochs):
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

# Key methods
model.train()    # Training mode
model.eval()     # Evaluation mode
torch.no_grad()  # Context manager: wrap inference in "with torch.no_grad():" to skip gradient tracking
```
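
The evaluation counterpart of the loop above combines `model.eval()` and `torch.no_grad()`. This is a minimal sketch, not the notebook's `ModelTrainer.validate`: the small model with dropout, the random validation data, and the `evaluate` helper are all hypothetical stand-ins.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for a trained model and a validation set
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))
val_data = TensorDataset(torch.randn(40, 5), torch.randint(0, 3, (40,)))
val_loader = DataLoader(val_data, batch_size=16)

def evaluate(model, loader):
    """Return classification accuracy as a fraction in [0, 1]."""
    model.eval()                       # turn off dropout, use running batch-norm stats
    correct, total = 0, 0
    with torch.no_grad():              # no computation graph is built
        for batch_x, batch_y in loader:
            predicted = model(batch_x).argmax(dim=1)
            correct += (predicted == batch_y).sum().item()
            total += batch_y.size(0)
    return correct / total

accuracy = evaluate(model, val_loader)
print(f"Validation accuracy: {accuracy:.2%}")
```

Call `model.train()` again before resuming training, since `model.eval()` stays in effect until it is reversed.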

### 🎓 **Exam Tips**
1. **Understand the mathematical formulas** behind each architecture
2. **Know when to use each model type** based on data characteristics
3. **Remember the vanishing gradient problem** and how LSTM/GRU gating mitigates it
4. **Practice implementing** basic versions from scratch
5. **Understand the role of gates** in LSTM/GRU
6. **Know common hyperparameters** and their typical ranges

---
**💡 Good luck with your exam! This notebook covers the core neural network implementations listed above: ANN, CNN, RNN, LSTM, and GRU.**