In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import norse.torch as norse
import numpy as np
# Hyperparameters
batch_size = 64
epochs = 10
learning_rate = 0.01



In [2]:
def get_data_loader():
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader







In [3]:
class SNN(nn.Module):
    def __init__(self):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.lif1 = norse.LIFCell()
        self.fc2 = nn.Linear(256, 10)
        self.lif2 = norse.LIFCell()
    

    
    def forward(self, x):
        batch_size, _, _, _ = x.shape
        x = x.view(batch_size, -1)  # Flatten the input
        spk, mem = self.lif1(self.fc1(x))
        spk, mem = self.lif2(self.fc2(spk))
        return spk



In [16]:
# Hebbian Learning Rule
class HebbianLearning:
    def __init__(self, lr=0.01):
        self.lr = lr
    
    def update_weights(self, weights, inputs, outputs):
        if isinstance(inputs, norse.LIFFeedForwardState):
            inputs = inputs.v  # Extract membrane potential if LIFFeedForwardState
        if isinstance(outputs, norse.LIFFeedForwardState):
            outputs = outputs.v  # Extract membrane potential if LIFFeedForwardState
        
        # Ensure tensors have correct dimensions
        batch_size = inputs.shape[0]
        inputs = inputs.view(batch_size, -1)  # Reshape inputs (64, 784)
        outputs = outputs.view(batch_size, -1)  # Reshape outputs (64, 256)
        
        # Compute batch-wise Hebbian weight update
        delta_w = self.lr * torch.mm(outputs.T, inputs) / batch_size  # (256, 784)
        
        # Ensure weights match the expected shape (256, 784)
        if weights.shape != delta_w.shape:
            raise RuntimeError(f"Shape mismatch: weights={weights.shape}, delta_w={delta_w.shape}")
        
        return weights + delta_w

In [17]:
# Transformer-Based Optimization
class TransformerOptimizer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TransformerOptimizer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=2)
        self.fc = nn.Linear(input_dim, output_dim)
    
    def forward(self, x):
        x = x.unsqueeze(1)  # Add sequence dimension
        attn_output, _ = self.attention(x, x, x)
        return self.fc(attn_output.squeeze(1))



In [18]:
# Modified SNN with Bio-Transformer Learning (BTL)
class BTL_SNN(nn.Module):
    def __init__(self):
        super(BTL_SNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.lif1 = norse.LIFCell()
        self.hebbian = HebbianLearning(lr=0.005)
        self.transformer = TransformerOptimizer(256, 256)
        self.fc2 = nn.Linear(256, 10)
        self.lif2 = norse.LIFCell()
    
    def forward(self, x):
        batch_size, _, _, _ = x.shape
        x = x.view(batch_size, -1)  # Flatten the input
        x = self.fc1(x)
        spk, mem = self.lif1(x)
        x = self.hebbian.update_weights(x, spk, mem)
        x = self.transformer(x)
        spk, mem = self.lif2(self.fc2(x))
        return spk



In [19]:
# Training function
def train(model, train_loader, optimizer, criterion):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")



In [21]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize models and training parameters
snn_model = BTL_SNN().to(device)
optimizer = optim.Adam(snn_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Load data and train
train_loader, test_loader = get_data_loader()
train(snn_model, train_loader, optimizer, criterion)


RuntimeError: Shape mismatch: weights=torch.Size([64, 256]), delta_w=torch.Size([256, 256])