In [2]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# Hyperparameters
num_epochs = 10
sequence_length = 100

# Load FashionMNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

# Define your SNN model
snn_model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(snn_model.parameters(), lr=0.001)

# Training loop
for epoch in range(num_epochs):
    correct_train = 0
    total_train = 0
    for i, (images, labels) in enumerate(train_loader):
        # Reset the state of the SNN for each sequence
        snn_model.zero_grad()

        # Forward pass (FPTT)
        for t in range(sequence_length):
            # Flatten the input image
            images = images.view(-1, 784)

            # Forward pass through the SNN

            spikes = snn_model(images)


        # Compute loss and perform backpropagation
        loss = criterion(spikes, labels)
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(spikes.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    train_accuracy = 100 * correct_train / total_train
    print(f'Training Accuracy after Epoch {epoch + 1}: {train_accuracy:.2f}%')

# Test the model
snn_model.eval()
correct_test = 0
total_test = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(-1, 784)
        spikes = snn_model(images)
        _, predicted = torch.max(spikes.data, 1)
        total_test += labels.size(0)
        correct_test += (predicted == labels).sum().item()

test_accuracy = 100 * correct_test / total_test
print(f'Test Accuracy: {test_accuracy:.2f}%')

# Save the trained model to a file
torch.save(snn_model.state_dict(), 'my_snn_model2.pth')


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [2/10], Step [41100/60000], Loss: 0.0000
Epoch [2/10], Step [41200/60000], Loss: 0.1264
Epoch [2/10], Step [41300/60000], Loss: 1.3054
Epoch [2/10], Step [41400/60000], Loss: 2.9510
Epoch [2/10], Step [41500/60000], Loss: 0.4922
Epoch [2/10], Step [41600/60000], Loss: 3.0933
Epoch [2/10], Step [41700/60000], Loss: 0.5044
Epoch [2/10], Step [41800/60000], Loss: 0.9021
Epoch [2/10], Step [41900/60000], Loss: 0.0087
Epoch [2/10], Step [42000/60000], Loss: 0.7057
Epoch [2/10], Step [42100/60000], Loss: 0.4995
Epoch [2/10], Step [42200/60000], Loss: 2.5985
Epoch [2/10], Step [42300/60000], Loss: 0.0000
Epoch [2/10], Step [42400/60000], Loss: 0.0014
Epoch [2/10], Step [42500/60000], Loss: 0.4796
Epoch [2/10], Step [42600/60000], Loss: 0.0834
Epoch [2/10], Step [42700/60000], Loss: 0.2915
Epoch [2/10], Step [42800/60000], Loss: 0.1588
Epoch [2/10], Step [42900/60000], Loss: 0.0013
Epoch [2/10], Step [43000/60000], Loss: 0.