In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Transformation : Normalisation simple [0,1]
transform = transforms.Compose([
    transforms.ToTensor()
])

# Chargement du dataset MNIST
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Dataloader
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 485kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.34MB/s]


In [2]:
def poisson_encode(img, time_steps=20):
    """ Encode une image en train de Poisson binaire pour SNN """
    img = img.view(-1, 28, 28)
    spike_train = torch.rand((time_steps, *img.shape)) < img.unsqueeze(0)
    return spike_train.float()


In [3]:
sample_img, _ = train_dataset[0]
encoded_spike = poisson_encode(sample_img, time_steps=20)
print(encoded_spike.shape)  # (time_steps, 1, 28, 28)

torch.Size([20, 1, 28, 28])


In [4]:
import torch.nn as nn

class SparseSNN(nn.Module):
    def __init__(self, input_size=28*28, hidden_size=500, output_size=10, sparsity=0.8):
        super(SparseSNN, self).__init__()

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

        self.sparsity = sparsity

    def forward(self, x_spike_train):
        batch_size = x_spike_train.shape[1]
        mem = torch.zeros(batch_size, self.fc2.out_features, device=x_spike_train.device)

        for t in range(x_spike_train.shape[0]):
            x = x_spike_train[t].view(batch_size, -1)

            # Ajout de sparsité artificielle
            mask = (torch.rand_like(x) > self.sparsity).float()
            x = x * mask

            hidden = torch.relu(self.fc1(x))
            output = self.fc2(hidden)

            mem += output

        return mem / x_spike_train.shape[0]


In [5]:
def stdp_update(pre_spikes, post_spikes, weights, lr=1e-3):
    """ Mise à jour locale des poids avec règle Hebbienne simplifiée """
    delta_w = torch.bmm(post_spikes.unsqueeze(2), pre_spikes.unsqueeze(1))
    weights.data += lr * delta_w.mean(dim=0)
    weights.data = torch.clamp(weights.data, -1.0, 1.0)


In [6]:
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SparseSNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

num_epochs = 5
time_steps = 20

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0

    start_time = time.time()

    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        encoded = poisson_encode(imgs, time_steps=time_steps).to(device)
        outputs = model(encoded)

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()

    end_time = time.time()

    print(f'Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f} - Accuracy: {correct/len(train_dataset):.4f} - Time: {end_time-start_time:.2f}s')


Epoch 1/5 - Loss: 2123.9198 - Accuracy: 0.3773 - Time: 51.57s
Epoch 2/5 - Loss: 1972.7196 - Accuracy: 0.5776 - Time: 49.59s
Epoch 3/5 - Loss: 1589.4377 - Accuracy: 0.6837 - Time: 49.62s
Epoch 4/5 - Loss: 1144.2993 - Accuracy: 0.7535 - Time: 52.22s
Epoch 5/5 - Loss: 866.7774 - Accuracy: 0.7961 - Time: 49.46s


In [7]:
from datetime import datetime

model.eval()
test_correct = 0
total_energy = 0
total_time = 0

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        start = time.perf_counter()
        encoded = poisson_encode(imgs, time_steps=time_steps).to(device)
        outputs = model(encoded)
        end = time.perf_counter()

        preds = outputs.argmax(dim=1)
        test_correct += (preds == labels).sum().item()

        total_time += (end - start)
        # Simulation grossière de consommation énergétique
        total_energy += (end - start) * 0.5  # (arbitraire : 0.5W utilisation estimée)

print(f'Test Accuracy: {test_correct/len(test_dataset):.4f}')
print(f'Test Inference Time (total): {total_time:.2f}s')
print(f'Estimated Energy Consumed: {total_energy:.4f} Joules')


Test Accuracy: 0.8165
Test Inference Time (total): 4.38s
Estimated Energy Consumed: 2.1878 Joules


**TEST AVEC LES 3 DATASETS **

In [8]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Définition d'une fonction pour charger le dataset
def load_dataset(name, batch_size=64):
    """
    Charge le dataset spécifié (MNIST, Fashion-MNIST ou CIFAR-10)
    """
    # Transformations : convertit en Tensor et normalise entre 0 et 1
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    if name == 'MNIST':
        train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    elif name == 'FashionMNIST':
        train_set = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_set = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif name == 'CIFAR10':
        # Pour CIFAR-10 : normaliser couleurs entre 0 et 1
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),  # On passe en niveaux de gris pour être cohérent
            transforms.Resize((28, 28)),  # Redimensionne à 28x28
            transforms.ToTensor()
        ])
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    else:
        raise ValueError("Dataset non supporté : Choisir MNIST, FashionMNIST ou CIFAR10")

    # Chargement sous forme de DataLoader
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


In [9]:
def poisson_encode(images, time_steps=20):
    """
    Transforme des images continues en train d'impulsions binaires (spikes) sur plusieurs temps
    """
    # images: (batch_size, 1, 28, 28)
    images = images.view(images.size(0), -1)  # aplatissement images en vecteur (batch_size, 784)

    # Génère des spikes selon une distribution de Bernoulli basée sur l'intensité
    spike_train = torch.rand(time_steps, *images.shape).to(images.device) < images.unsqueeze(0)

    # Retourne le train de spikes (time_steps, batch_size, 784)
    return spike_train.float()


In [10]:
import torch.nn as nn

class SparseSNN(nn.Module):
    """
    Modèle simple de réseau de neurones à pics avec sparsité imposée
    """
    def __init__(self, input_size=28*28, hidden_size=500, output_size=10, sparsity=0.8):
        super(SparseSNN, self).__init__()

        # Couche entièrement connectée entre input et hidden
        self.fc1 = nn.Linear(input_size, hidden_size)

        # Couche entre hidden et output
        self.fc2 = nn.Linear(hidden_size, output_size)

        # Taux de sparsité (0.8 => 80% des entrées sont éteintes à chaque temps)
        self.sparsity = sparsity

    def forward(self, spike_train):
        """
        spike_train: (time_steps, batch_size, input_size)
        """
        batch_size = spike_train.shape[1]

        # Accumulateur de potentiel de membrane (sur les sorties)
        membrane_potential = torch.zeros(batch_size, self.fc2.out_features, device=spike_train.device)

        # Boucle temporelle
        for t in range(spike_train.shape[0]):
            x = spike_train[t]  # état à l'instant t
            # Appliquer la sparsité (désactiver aléatoirement des neurones d'entrée)
            mask = (torch.rand_like(x) > self.sparsity).float()
            x_sparse = x * mask

            # Passage dans réseau
            hidden = torch.relu(self.fc1(x_sparse))
            output = self.fc2(hidden)

            # Accumulation du potentiel
            membrane_potential += output

        # Retourne moyenne sur tous les time steps
        return membrane_potential / spike_train.shape[0]


In [11]:
def train(model, train_loader, optimizer, criterion, device, time_steps=20):
    """
    Entraînement du modèle SNN sur un epoch
    """
    model.train()
    total_loss = 0
    correct = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Encodage des images en train de spikes
        spike_train = poisson_encode(images, time_steps=time_steps)

        # Passage du train de spikes dans le réseau
        outputs = model(spike_train)

        # Calcul de la perte
        loss = criterion(outputs, labels)

        # Rétropropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()

    accuracy = correct / len(train_loader.dataset)
    return total_loss, accuracy

def test(model, test_loader, device, time_steps=20):
    """
    Test du modèle SNN
    """
    model.eval()
    correct = 0
    total_inference_time = 0.0
    total_energy = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            spike_train = poisson_encode(images, time_steps=time_steps)

            start.record()
            outputs = model(spike_train)
            end.record()

            torch.cuda.synchronize()  # Synchroniser les événements CUDA
            inference_time = start.elapsed_time(end) / 1000.0  # en secondes

            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()

            total_inference_time += inference_time
            total_energy += inference_time * 0.5  # Consommation simulée: 0.5 Joules par seconde

    accuracy = correct / len(test_loader.dataset)
    return accuracy, total_inference_time, total_energy


In [None]:
def main(dataset_name='MNIST', epochs=5, batch_size=64, time_steps=20):
    """
    Fonction principale pour entraîner et tester un modèle SNN bio-inspiré
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Chargement des données
    train_loader, test_loader = load_dataset(dataset_name, batch_size=batch_size)

    # Initialisation du modèle
    model = SparseSNN().to(device)

    # Définition de l'optimiseur et de la fonction de perte
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        loss, train_acc = train(model, train_loader, optimizer, criterion, device, time_steps=time_steps)
        print(f"Train Loss: {loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # Test final
    test_acc, inference_time, energy = test(model, test_loader, device, time_steps=time_steps)
    print(f"\n--- Résultats Test sur {dataset_name} ---")
    print(f"Accuracy: {test_acc:.4f}")
    print(f"Temps d'inférence total: {inference_time:.2f} secondes")
    print(f"Énergie estimée: {energy:.2f} Joules")

# Exemple d'utilisation
main('MNIST')
# main('FashionMNIST')
# main('CIFAR10')



Epoch 1/5
