In [None]:
pip install pytorch_spiking

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import time
import pytorch_spiking

# Seed the torch and NumPy global RNGs with one shared value so that
# repeated runs of the notebook draw the same random numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def _load_fashion_mnist(train):
    """Load one FashionMNIST split as ``(pictures, labels)`` NumPy arrays.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.

    Returns:
        A pair ``(pictures, labels)``: ``pictures`` is a float32 array of the
        images rescaled from 0..255 to 0..1, ``labels`` is an int64 array of
        the class indices, in dataset order.
    """
    pictures, labels = zip(
        *torchvision.datasets.FashionMNIST(".", train=train, download=True)
    )
    pictures = np.asarray([np.array(pic) for pic in pictures], dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int64)
    # normalize images so values are between 0 and 1
    return pictures / 255.0, labels


# Both splits go through the same helper so the test arrays really come from
# the test split (they were previously rebuilt from the training data).
train_pictures, train_labels = _load_fashion_mnist(train=True)
test_pictures, test_labels = _load_fashion_mnist(train=False)

# every picture must have exactly one label
assert train_pictures.shape[0] == train_labels.shape[0]
assert test_pictures.shape[0] == test_labels.shape[0]

# Human-readable class names; a name's position in the list is the integer
# label it belongs to (the list is indexed by label when titling plots).
class_labels = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",  # labels 0-4
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",  # labels 5-9
]
num_classes = len(class_labels)

# Preview the first 25 training pictures in a 5x5 grid, each titled with the
# name of its class.
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
for i, ax in enumerate(axes.flat):
    ax.imshow(train_pictures[i], cmap=plt.cm.binary)
    ax.axis("off")
    ax.set_title(class_labels[train_labels[i]])

In [None]:
def _iter_minibatches(images, labels, minibatch_size):
    """Yield ``(inputs, targets)`` tensor pairs covering every sample once, in order.

    The two trailing axes of ``images`` are merged into a single feature axis
    while all leading axes are kept. The last batch may hold fewer than
    ``minibatch_size`` samples.
    """
    for start in range(0, images.shape[0], minibatch_size):
        stop = start + minibatch_size
        batch = images[start:stop]
        # flatten images
        flat_shape = tuple(batch.shape[:-2]) + (batch.shape[-2] * batch.shape[-1],)
        yield (
            torch.as_tensor(batch.reshape(flat_shape), dtype=torch.float32),
            torch.as_tensor(labels[start:stop], dtype=torch.long),
        )


def _optimizer_step(model, optimizer, inputs, targets):
    """Take one optimizer step on a batch.

    Returns:
        ``(loss, output)`` from the first evaluation of the model, i.e. computed
        with the parameters as they were before this step updated them.
    """
    first_eval = {}

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        # sparse categorical cross entropy loss
        loss = torch.nn.functional.cross_entropy(output, targets)
        loss.backward()  # bptt
        # LBFGS may re-evaluate the closure several times within one step;
        # only the first evaluation is reported back to the caller.
        if not first_eval:
            first_eval["loss"] = loss.detach()
            first_eval["output"] = output.detach()
        return loss

    # LBFGS.step() requires the closure; calling it without one raises TypeError.
    optimizer.step(closure)
    return first_eval["loss"], first_eval["output"]


def train(
    input_ssn,
    train_x,
    test_x,
    train_y=None,
    test_y=None,
    num_epochs=50,
    minibatch_size=32,
    patience=5,
):
    """Train ``input_ssn`` with LBFGS, then print its accuracy on the test data.

    Progress, the training wall-clock time and the final test accuracy are
    printed; nothing is returned. The model is left in eval mode.

    Args:
        input_ssn: the ``torch.nn.Module`` to optimise in place. Its output is
            indexed as ``(batch, class)``.
        train_x: training inputs; the leading axis indexes samples and the two
            trailing axes are flattened into one feature axis per batch.
        test_x: test inputs, laid out like ``train_x``.
        train_y: integer class labels for ``train_x``. When ``None``, the
            module-level ``train_labels`` array is used.
        test_y: integer class labels for ``test_x``. When ``None``, the
            module-level ``test_labels`` array is used.
        num_epochs: maximum number of passes over the training data.
        minibatch_size: number of samples per batch. A trailing partial batch
            is used rather than dropped.
        patience: training stops once more than this many epochs have passed
            since the best mean epoch loss.

    Raises:
        ValueError: if a split is empty or its inputs and labels differ in length.
    """
    # Fall back to the notebook-level label arrays when none are passed.
    if train_y is None:
        train_y = train_labels
    if test_y is None:
        test_y = test_labels

    for split, xs, ys in (("train", train_x, train_y), ("test", test_x, test_y)):
        if xs.shape[0] == 0:
            raise ValueError(f"{split} set is empty")
        if len(ys) != xs.shape[0]:
            raise ValueError(
                f"{split} set has {xs.shape[0]} samples but {len(ys)} labels"
            )

    optimizer = torch.optim.LBFGS(input_ssn.parameters())
    start_time = time.time()
    input_ssn.train()

    # Initialize variables for early stopping
    best_loss = float("inf")
    best_epoch = 0

    for epoch in range(num_epochs):
        correct = 0
        seen = 0
        loss_total = 0.0
        for inputs, targets in _iter_minibatches(train_x, train_y, minibatch_size):
            loss, output = _optimizer_step(input_ssn, optimizer, inputs, targets)

            batch_size = targets.shape[0]
            correct += int(torch.eq(torch.argmax(output, dim=1), targets).sum())
            loss_total += float(loss) * batch_size
            seen += batch_size

        train_acc = torch.tensor(correct / seen)
        print(f"Accuracy(Training) ({epoch}): {train_acc.numpy()}")

        # Check if the loss has improved for early stopping. The mean loss over
        # the whole epoch is used, not just the loss of the last batch.
        epoch_loss = loss_total / seen
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_epoch = epoch
        elif epoch - best_epoch > patience:
            print("Early stopping due to no improvement in loss.")
            break

    train_time = time.time() - start_time
    print("Training time:", train_time)

    # compute test accuracy
    input_ssn.eval()
    correct = 0
    seen = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, targets in _iter_minibatches(test_x, test_y, minibatch_size):
            output = input_ssn(inputs)
            correct += int(torch.eq(torch.argmax(output, dim=1), targets).sum())
            seen += targets.shape[0]

    test_acc = torch.tensor(correct / seen)

    print(f"Accuracy(Testing) {test_acc.numpy()}")

In [None]:
# Present each picture for `n_steps` consecutive time steps: insert a new axis
# right after the sample axis and repeat the picture along it.
n_steps = 10
train_stream = np.repeat(train_pictures[:, np.newaxis], n_steps, axis=1)
test_stream = np.repeat(test_pictures[:, np.newaxis], n_steps, axis=1)

In [None]:
class SelfAttention(torch.nn.Module):
    """Single-head dot-product self-attention.

    Queries, keys and values are three separate learned linear maps of the
    input, each keeping the feature width ``input_dim``. Attention scores are
    the raw (unscaled) query-key dot products, normalised by a softmax over
    the last axis, and the result is the score-weighted sum of the values.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Created in the order query, key, value.
        for name in ("query", "key", "value"):
            setattr(self, name, torch.nn.Linear(input_dim, input_dim))

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``; the output has the shape of ``x``."""
        scores = torch.matmul(self.query(x), self.key(x).transpose(-2, -1))
        weights = scores.softmax(dim=-1)
        return torch.matmul(weights, self.value(x))

In [None]:
def _spiking_elu():
    """Build an ELU wrapped in a spiking activation with spiking-aware training on."""
    # set spiking_aware_training and a moderate dt
    return pytorch_spiking.SpikingActivation(
        torch.nn.ELU(alpha=1.0),  # exponential linear unit
        dt=0.8,
        spiking_aware_training=True,
    )


# 13 modules in total, applied in the order listed: three
# Linear -> ... -> spiking/pooling stages narrowing 784 -> 256 -> 128 -> 64,
# followed by a 10-way output layer.
_layers = [
    torch.nn.Linear(784, 256),
    SelfAttention(256),
    torch.nn.SELU(),
    _spiking_elu(),
    torch.nn.Linear(256, 128),
    SelfAttention(128),
    torch.nn.GELU(),
    torch.nn.Dropout(0.2),
    _spiking_elu(),
    torch.nn.Linear(128, 64),
    torch.nn.Dropout(0.5),
    pytorch_spiking.TemporalAvgPool(),
    torch.nn.Linear(64, 10),
]
spikeaware_model = torch.nn.Sequential(*_layers)
train(spikeaware_model, train_stream, test_stream)