In [1]:
pip install pytorch_spiking

Collecting pytorch_spiking
  Downloading pytorch_spiking-0.1.0-py3-none-any.whl (10 kB)
Installing collected packages: pytorch_spiking
Successfully installed pytorch_spiking-0.1.0


In [2]:
pip install --upgrade torch torchviz graphviz


Collecting torch
  Downloading torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m65.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m66.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylin

In [4]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
import pytorch_spiking

torch.manual_seed(0)
np.random.seed(0)


In [5]:
class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = torch.nn.Linear(input_dim, input_dim)
        self.key = torch.nn.Linear(input_dim, input_dim)
        self.value = torch.nn.Linear(input_dim, input_dim)

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn_weights = torch.nn.functional.softmax(q @ k.transpose(-2, -1), dim=-1)
        return attn_weights @ v

In [6]:
spikeaware_model = torch.nn.Sequential(
     torch.nn.Linear(784, 256),   #13 layers
     SelfAttention(256),
     torch.nn.SELU(),
    # set spiking_aware_training and a moderate dt
    pytorch_spiking.SpikingActivation(
        torch.nn.ELU(alpha=1.0), dt=0.5, spiking_aware_training=True #exponential linear unit
    ),
    torch.nn.Linear(256,128),
    SelfAttention(128),
    torch.nn.GELU(),
     torch.nn.Dropout(0.4),
       pytorch_spiking.SpikingActivation(
        torch.nn.ELU(alpha=1.0), dt=0.8, spiking_aware_training=True #exponential linear unit
    ),
    torch.nn.Linear(128,64),
    torch.nn.Dropout(0.4),
    pytorch_spiking.TemporalAvgPool(),
    torch.nn.Linear(64, 10),
)


In [9]:
pip install --upgrade torch torchvision torchaudio

Collecting torchvision
  Downloading torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m
Collecting torchaudio
  Downloading torchaudio-2.1.1-cp310-cp310-manylinux1_x86_64.whl (3.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m72.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchvision, torchaudio
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.16.0+cu118
    Uninstalling torchvision-0.16.0+cu118:
      Successfully uninstalled torchvision-0.16.0+cu118
  Attempting uninstall: torchaudio
    Found existing installation: torchaudio 2.1.0+cu118
    Uninstalling torchaudio-2.1.0+cu118:
      Successfully uninstalled torchaudio-2.1.0+cu118
Successfully installed torchaudio-2.1.1 torchvision-0.16.1


In [10]:
pip install --upgrade torchviz



In [14]:

import torch

# Assuming you have defined the spikeaware_model
sample_input = torch.randn((1, 784)).unsqueeze(0)  # Add an extra dimension to create a 3D tensor
trace_model = torch.jit.trace(spikeaware_model, sample_input)


# Create a visualization of the model architecture
dot = make_dot(trace_model(sample_input), params=dict(spikeaware_model.named_parameters()))

# Save the visualization as a PNG file
dot.format = 'png'
dot.render('spikeaware_model_architecture', format='png', cleanup=True)


RuntimeError: ignored