In [5]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras 
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
import time

In [6]:
class ANN:
    def __init__(self, input_size=784, hidden_sizes=[512, 256], output_size=10):
        self.model = self.build_model(input_size, hidden_sizes, output_size)
    
    def build_model(self, input_size, hidden_sizes, output_size):
        model = models.Sequential()
        model.add(layers.Dense(hidden_sizes[0], activation='relu', input_shape=(input_size,)))
        model.add(layers.Dropout(0.2))
        for hidden_size in hidden_sizes[1:]:
            model.add(layers.Dense(hidden_size, activation='relu'))
            model.add(layers.Dropout(0.2))
        model.add(layers.Dense(output_size, activation='softmax'))
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        return model
    
    def train(self, X_train, y_train, X_val, y_val, epochs=15, batch_size=128):
        start_time = time.time()
        history = self.model.fit(
            X_train, y_train, validation_data=(X_val, y_val),
            epochs=epochs, batch_size=batch_size, verbose=1
        )
        training_time = time.time() - start_time
        print(f"\nANN Training time: {training_time:.2f} seconds")
        return history, training_time
    
    def evaluate(self, X_test, y_test):
        loss, accuracy = self.model.evaluate(X_test, y_test, verbose=0)
        return loss, accuracy
    
    def predict(self, X):
        return np.argmax(self.model.predict(X, verbose=0), axis=1)

In [7]:
class ANNtoSNN:
    def __init__(self, ann_model, timesteps=100):
        self.timesteps = timesteps
        self.weights = []
        self.biases = []
        
        # Extract weights from Dense layers only
        for layer in ann_model.layers:
            if isinstance(layer, layers.Dense):
                w, b = layer.get_weights()
                self.weights.append(w)
                self.biases.append(b)
        
        print(f"Extracted {len(self.weights)} layers from ANN")
        for i, (w, b) in enumerate(zip(self.weights, self.biases)):
            print(f"  Layer {i}: Weight shape {w.shape}, Bias shape {b.shape}")

    def forward(self, x):
        """Forward pass through SNN with rate-based encoding"""
        # Rate-based input encoding (normalized pixel values as firing rates)
        current_activity = x.copy()
        
        # Process through all layers except the last one
        for i in range(len(self.weights) - 1):
            # Weighted sum
            activation = np.dot(current_activity, self.weights[i]) + self.biases[i]
            # ReLU activation (mimics spiking behavior)
            current_activity = np.maximum(0, activation)
            # Normalize to keep values in reasonable range
            if np.max(current_activity) > 0:
                current_activity = current_activity / np.max(current_activity)
        
        # Output layer - no normalization
        output_activation = np.dot(current_activity, self.weights[-1]) + self.biases[-1]
        
        return output_activation

    def predict(self, x):
        """Predict single sample"""
        output = self.forward(x)
        return np.argmax(output)
    
    def predict_batch(self, X):
        """Predict multiple samples"""
        predictions = []
        for i in range(len(X)):
            predictions.append(self.predict(X[i]))
        return np.array(predictions)


In [8]:
class ANNtoSNN_Spiking:
    
    def __init__(self, ann_model, timesteps=100, threshold=1.0):
        self.timesteps = timesteps
        self.threshold = threshold
        self.weights = []
        self.biases = []
        
        
        for layer in ann_model.layers:
            if isinstance(layer, layers.Dense):
                w, b = layer.get_weights()
                self.weights.append(w)
                self.biases.append(b)
        
        print(f"\nSpiking SNN: Extracted {len(self.weights)} layers")

    def poisson_encoding(self, x, timesteps):
        
        spikes = np.random.rand(timesteps, len(x)) < np.clip(x, 0, 1)
        return spikes.astype(float)

    def forward(self, x):
        
        spike_train = self.poisson_encoding(x, self.timesteps)
        
        
        n_layers = len(self.weights)
        membrane_potentials = [np.zeros(w.shape[1]) for w in self.weights]
        spike_counts = [np.zeros(w.shape[1]) for w in self.weights]
        
        
        for t in range(self.timesteps):
            current_spikes = spike_train[t]
            
            
            for layer_idx in range(n_layers):
                
                weighted_input = np.dot(current_spikes, self.weights[layer_idx])
                membrane_potentials[layer_idx] += weighted_input
                
               
                membrane_potentials[layer_idx] += self.biases[layer_idx] / self.timesteps
                
               
                membrane_potentials[layer_idx] = np.maximum(0, membrane_potentials[layer_idx])
                
                
                spikes = (membrane_potentials[layer_idx] >= self.threshold).astype(float)
                spike_counts[layer_idx] += spikes
                
               
                membrane_potentials[layer_idx] = np.where(
                    spikes > 0,
                    0,  
                    membrane_potentials[layer_idx]
                )
                
               
                current_spikes = spikes
        
     
        return spike_counts[-1]

    def predict(self, x):
      
        spike_counts = self.forward(x)
        return np.argmax(spike_counts)
    
    def predict_batch(self, X):
       
        predictions = []
        for i in range(len(X)):
            if i % 100 == 0:
                print(f"  Processing sample {i}/{len(X)}...", end='\r')
            predictions.append(self.predict(X[i]))
        print()
        return np.array(predictions)


In [9]:
def plot_training_history(ann_history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
  
    ax1.plot(ann_history.history['accuracy'], label='Train Acc', marker='o', linewidth=2)
    ax1.plot(ann_history.history['val_accuracy'], label='Val Acc', marker='s', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Accuracy', fontsize=12)
    ax1.set_title('ANN Training Accuracy', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)
    

    ax2.plot(ann_history.history['loss'], label='Train Loss', marker='o', linewidth=2)
    ax2.plot(ann_history.history['val_loss'], label='Val Loss', marker='s', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Loss', fontsize=12)
    ax2.set_title('ANN Training Loss', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig('ann_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()



In [10]:
def plot_accuracy_comparison(results):
    fig, ax = plt.subplots(figsize=(10, 6))
    
    models = list(results.keys())
    accuracies = list(results.values())
    colors = ['#3498db', '#e74c3c', '#2ecc71']
    
    bars = ax.bar(models, accuracies, color=colors[:len(models)], 
                  alpha=0.7, edgecolor='black', linewidth=2)
    
    ax.set_ylabel('Accuracy', fontsize=14)
    ax.set_ylim([0, 1])
    ax.set_title('Model Accuracy Comparison', fontsize=16, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    
  
    for bar, acc in zip(bars, accuracies):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                f'{acc:.4f}',
                ha='center', va='bottom', fontsize=12, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig('accuracy_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()



In [11]:
def plot_sample_predictions(X_test, y_test, ann_pred, snn_pred, n_samples=10):
    fig, axes = plt.subplots(2, n_samples, figsize=(20, 5))
    indices = np.random.choice(len(X_test), n_samples, replace=False)
    
    for idx, i in enumerate(indices):
        img = X_test[i].reshape(28, 28)
        
    
        axes[0, idx].imshow(img, cmap='gray')
        axes[0, idx].axis('off')
        color = 'green' if ann_pred[i] == y_test[i] else 'red'
        axes[0, idx].set_title(f'ANN: {ann_pred[i]}\nTrue: {y_test[i]}', 
                               color=color, fontweight='bold', fontsize=10)
        
      
        axes[1, idx].imshow(img, cmap='gray')
        axes[1, idx].axis('off')
        color = 'green' if snn_pred[i] == y_test[i] else 'red'
        axes[1, idx].set_title(f'SNN: {snn_pred[i]}\nTrue: {y_test[i]}', 
                               color=color, fontweight='bold', fontsize=10)
    
    fig.text(0.02, 0.75, 'ANN', fontsize=14, fontweight='bold', rotation=90, va='center')
    fig.text(0.02, 0.25, 'SNN', fontsize=14, fontweight='bold', rotation=90, va='center')
    
    plt.tight_layout()
    plt.savefig('sample_predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

In [12]:
def main():
    print("="*70)
    print("ANN to SNN Conversion - MNIST Digit Classification")
    print("="*70)
    
  
    print("\n[1] Loading MNIST dataset...")
    (X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
    X_train = X_train.reshape(-1, 784).astype('float32') / 255.0
    X_test = X_test.reshape(-1, 784).astype('float32') / 255.0
    y_train_cat = to_categorical(y_train, 10)
    y_test_cat = to_categorical(y_test, 10)

  
    n_train, n_test = 20000, 2000
    X_train_subset, y_train_subset = X_train[:n_train], y_train[:n_train]
    y_train_cat_subset = y_train_cat[:n_train]
    X_test_subset, y_test_subset = X_test[:n_test], y_test[:n_test]
    y_test_cat_subset = y_test_cat[:n_test]

    print(f"Training samples: {n_train}, Test samples: {n_test}")

   
    print("\n[2] Training ANN...")
    ann = ANN(input_size=784, hidden_sizes=[256, 128], output_size=10)
    ann_history, ann_time = ann.train(
        X_train_subset, y_train_cat_subset, 
        X_test_subset, y_test_cat_subset, 
        epochs=10, batch_size=128
    )
    
  
    print("\n[3] Evaluating ANN...")
    ann_loss, ann_accuracy = ann.evaluate(X_test_subset, y_test_cat_subset)
    ann_predictions = ann.predict(X_test_subset)
    print(f"✓ ANN Test Accuracy: {ann_accuracy:.4f}")

  
    print("\n[4] Converting ANN to Rate-based SNN...")
    snn_rate = ANNtoSNN(ann.model, timesteps=100)
    
    print("Evaluating Rate-based SNN...")
    snn_rate_predictions = snn_rate.predict_batch(X_test_subset)
    snn_rate_accuracy = np.mean(snn_rate_predictions == y_test_subset)
    print(f"✓ Rate-based SNN Test Accuracy: {snn_rate_accuracy:.4f}")

  
    print("\n[5] Converting ANN to Spiking SNN...")
    
   
    timesteps_to_use = 100
    threshold_to_use = 1.0  
    
    snn_spike = ANNtoSNN_Spiking(ann.model, timesteps=timesteps_to_use, threshold=threshold_to_use)
    
    print(f"Configuration: timesteps={timesteps_to_use}, threshold={threshold_to_use}")
    print("Evaluating Spiking SNN (this may take a moment)...")
    snn_spike_predictions = snn_spike.predict_batch(X_test_subset)
    snn_spike_accuracy = np.mean(snn_spike_predictions == y_test_subset)
    print(f"✓ Spiking SNN Test Accuracy: {snn_spike_accuracy:.4f}")
    
    
    if snn_spike_accuracy < 0.75:
        print("\n If accuracy is too low, try threshold=0.5 or timesteps=150")
    elif snn_spike_accuracy >= 0.80:
        print("\n✓ Good SNN conversion! Spiking accuracy within expected range.")

    
    print("\n" + "="*70)
    print("RESULTS SUMMARY")
    print("="*70)
    print(f"ANN Accuracy:              {ann_accuracy:.4f}")
    print(f"Rate-based SNN Accuracy:   {snn_rate_accuracy:.4f} (drop: {(ann_accuracy - snn_rate_accuracy):.4f})")
    print(f"Spiking SNN Accuracy:      {snn_spike_accuracy:.4f} (drop: {(ann_accuracy - snn_spike_accuracy):.4f})")
    print(f"Training Time:             {ann_time:.2f} seconds")
    print("="*70)

 
  
    plot_training_history(ann_history)
    
    results = {
        'ANN': ann_accuracy,
        'Rate SNN': snn_rate_accuracy,
        'Spike SNN': snn_spike_accuracy
    }
    plot_accuracy_comparison(results)
    plot_sample_predictions(X_test_subset, y_test_subset, 
                           ann_predictions, snn_rate_predictions, n_samples=10)

In [None]:
if __name__ == "__main__":
    main()

ANN to SNN Conversion - MNIST Digit Classification

[1] Loading MNIST dataset...
Training samples: 20000, Test samples: 2000

[2] Training ANN...
Epoch 1/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 8ms/step - accuracy: 0.7103 - loss: 0.9553 - val_accuracy: 0.9160 - val_loss: 0.2901
Epoch 2/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - accuracy: 0.9252 - loss: 0.2511 - val_accuracy: 0.9315 - val_loss: 0.2159
Epoch 3/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - accuracy: 0.9470 - loss: 0.1828 - val_accuracy: 0.9460 - val_loss: 0.1785
Epoch 4/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - accuracy: 0.9641 - loss: 0.1236 - val_accuracy: 0.9490 - val_loss: 0.1703
Epoch 5/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 6ms/step - accuracy: 0.9699 - loss: 0.1012 - val_accuracy: 0.9515 - val_loss: 0.1463
Epoch 6/10
[1m157/157[0m [32m━━━━━━━━━━━━