In [17]:
import math

import torch
from torchvision import datasets, transforms

In [18]:
# %pip (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): torch is already imported in an earlier cell; restart the kernel after
# installing/upgrading, and consider pinning versions (torch==X.Y.Z) for reproducibility.
%pip install -q torch torchvision




In [7]:
# Pinned to the version recorded in this cell's output: the model code imports
# spikingjelly.clock_driven, which other releases may rename or remove.
%pip install -q spikingjelly==0.0.0.0.14


Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[K     |████████████████████████████████| 437 kB 5.3 MB/s eta 0:00:01
[?25hCollecting tqdm
  Downloading tqdm-4.66.1-py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 2.6 MB/s  eta 0:00:01
Installing collected packages: tqdm, spikingjelly
Successfully installed spikingjelly-0.0.0.0.14 tqdm-4.66.1


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import spikingjelly.clock_driven.encoding as encoding

class LIFNeuron(nn.Module):
    """
    Leaky Integrate-and-Fire (LIF) Neuron Model.

    Each call decays the stored state by ``exp(-1 / tau_syn)`` and adds the new
    input. It emits a spike (1.0) wherever the state reaches ``v_threshold`` and
    sets the state of the spiking elements to ``v_reset``. The state persists
    between calls; call :meth:`reset` before feeding an unrelated sample or batch.

    Note: spikes come from a hard threshold (``>=``), which carries no gradient.
    Trainable layers placed *before* this neuron receive no gradient through it.

    Parameters:
        - tau_syn (float): The synaptic time constant (must be non-zero).
        - v_threshold (float): The threshold potential for generating spikes.
        - v_reset (float): The reset potential after a spike.
        - initial_v (float): Stored on the module but currently NOT used; the
          state always starts at zero. NOTE(review): confirm whether the state
          should start at ``initial_v`` instead.
    """
    def __init__(self, tau_syn=5.0, v_threshold=1.0, v_reset=0.0, initial_v=-1.0):
        super(LIFNeuron, self).__init__()
        self.tau_syn = tau_syn
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.initial_v = initial_v
        # Placeholder state; forward() replaces it with a zero tensor matching the
        # input's shape and device the first time it is used.
        self.synaptic_current = torch.zeros(1)

    def reset(self):
        """Clear the stored state so the next forward call starts from zero."""
        self.synaptic_current = torch.zeros(1)

    def forward(self, synaptic_input):
        """
        Forward pass of the LIF neuron.

        Parameters:
            - synaptic_input (torch.Tensor): The input synaptic current.

        Returns:
            - spike (torch.Tensor): Float tensor of 0.0/1.0 with the input's shape,
              1.0 where a spike occurred.
        """
        state = self.synaptic_current
        if state.shape != synaptic_input.shape or state.device != synaptic_input.device:
            # Placeholder or stale state (first call, new batch size, new device):
            # restart from zero instead of failing to broadcast or mixing devices.
            state = torch.zeros_like(synaptic_input)

        # math.exp: the decay factor is a plain float, and torch.exp only accepts tensors.
        decay = math.exp(-1.0 / self.tau_syn)
        state = decay * state + synaptic_input

        fired = state >= self.v_threshold
        spike = fired.float()
        # masked_fill keeps the state's dtype/device (torch.tensor(v_reset) would be CPU-only).
        self.synaptic_current = state.masked_fill(fired, self.v_reset)
        return spike

class SimpleSpikingNetwork(nn.Module):
    """
    Simple spiking network: Poisson encoder -> LIF neurons -> Linear + ReLU -> Linear.

    The LIF neurons act directly on the Poisson-encoded input: they come *before*
    the hidden ``nn.Linear``. The network runs for a single time step per call.

    Parameters:
        - input_size (int): Number of input features per sample, i.e. the product
          of all non-batch dimensions of ``x``.
        - hidden_size (int): Size of the hidden layer.
        - output_size (int): Size of the output layer (e.g. number of classes).
    """
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleSpikingNetwork, self).__init__()
        # Input layer: converts values into 0/1 spikes.
        # NOTE(review): presumably expects inputs scaled to [0, 1] — confirm normalization.
        self.input_layer = encoding.PoissonEncoder()

        # LIF neurons, followed by the dense hidden layer
        self.hidden_layer = nn.Sequential(
            LIFNeuron(),
            nn.Linear(input_size, hidden_size),
            nn.ReLU()
        )

        # Output layer
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        Forward pass of the network.

        Parameters:
            - x (torch.Tensor): Input data, either flat ``(batch, input_size)`` or
              shaped ``(batch, d1, d2, ...)`` with ``d1*d2*... == input_size``.

        Returns:
            - output (torch.Tensor): Un-normalized scores, ``(batch, output_size)``
              for batched input.
        """
        # Encode input spikes
        spikes = self.input_layer(x)

        # nn.Linear only mixes the last dimension, so collapse every non-batch
        # dimension; already-flat (B, F) or unbatched (F,) inputs are left unchanged.
        if spikes.dim() > 2:
            spikes = torch.flatten(spikes, start_dim=1)

        # LIF -> Linear -> ReLU
        hidden = self.hidden_layer(spikes)

        # Output layer
        return self.output_layer(hidden)


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ---- Configuration ----
TRAIN_DATASET_PATH = "final_project/train_dataset.pth"
VAL_DATASET_PATH = "final_project/val_dataset.pth"
BATCH_SIZE = 64
SEED = 0

torch.manual_seed(SEED)

# Load train and validation datasets.
# SECURITY: these files are pickled Dataset objects; loading them with
# weights_only=False can execute arbitrary code, so only load files you created.
# The explicit flag is required on PyTorch >= 2.6, whose default rejects non-tensor objects.
# NOTE(review): the pickled datasets presumably store absolute sample paths from the session
# that created them (e.g. /content/CIFAR/...); those files must exist at the same location,
# otherwise reading a sample raises FileNotFoundError.
train_dataset = torch.load(TRAIN_DATASET_PATH, weights_only=False)
val_dataset = torch.load(VAL_DATASET_PATH, weights_only=False)

# Derive the flattened feature count from a real sample instead of hard-coding it
# (the previous 128 * 128 * 3 contradicted its own "32x32 CIFAR-10" comment).
sample_input, _ = train_dataset[0]
input_size = torch.as_tensor(sample_input).numel()
hidden_size = 128
output_size = 10  # number of classes

# Create an instance of the SimpleSpikingNetwork
model = SimpleSpikingNetwork(input_size, hidden_size, output_size)

# Define the loss function (CrossEntropyLoss for classification)
criterion = nn.CrossEntropyLoss()

# Define the optimizer (Adam optimizer)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Iterating a Dataset directly yields single unbatched samples with plain-int labels,
# which CrossEntropyLoss cannot consume; DataLoader yields (B, ...) inputs and (B,) labels.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


def reset_spiking_state(net):
    """Zero the stored state of every LIFNeuron in ``net`` so each batch is independent."""
    for module in net.modules():
        if isinstance(module, LIFNeuron):
            module.synaptic_current = torch.zeros(1)


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        # Without this, neuron state from the previous batch leaks into this one.
        reset_spiking_state(model)
        optimizer.zero_grad()
        # .float(): the Linear weights are float32.
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation loop (evaluate model on the validation dataset)
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for val_inputs, val_labels in val_loader:
            reset_spiking_state(model)
            val_outputs = model(val_inputs.float())
            predicted = val_outputs.argmax(dim=1)
            # Count *validation* labels (previously the last training batch was counted).
            total_samples += val_labels.size(0)
            total_correct += (predicted == val_labels).sum().item()

    # Guard against an empty validation set instead of dividing by zero.
    accuracy = total_correct / total_samples if total_samples else float("nan")
    print(f'Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {accuracy:.4f}')

# Save your trained model if needed
torch.save(model.state_dict(), 'trained_model.pth')


FileNotFoundError: [Errno 2] No such file or directory: '/content/CIFAR/frames_number_10_split_by_number/airplane/cifar10_airplane_0.npz'