In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import tqdm



# Define the convolutional neural network.
class CNN(nn.Module):
    """Small two-stage CNN classifier for single-channel 28x28 images.

    Layout:
        conv(1->32, 3x3, pad 1) -> relu1 -> 2x2 max-pool
        -> conv(32->64, 3x3, pad 1) -> relu2 -> 2x2 max-pool
        -> flatten -> fc1(64*7*7 -> 128) -> relu -> fc2(128 -> 10)

    The activation after ``fc1`` is stored as its own attribute ``self.relu``
    (separate from ``relu1``/``relu2``), so it can be replaced by another
    module after construction without touching the convolutional stages.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Shared by both stages; halves the spatial size each time (28 -> 14 -> 7).
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()   # activation between fc1 and fc2 (replaceable)
        self.relu1 = nn.ReLU()  # activation after conv1
        self.relu2 = nn.ReLU()  # activation after conv2

    def forward(self, x):
        """Compute class scores.

        Args:
            x: image batch of shape (N, 1, 28, 28). The 28x28 size is required
               because fc1 expects a 64*7*7 feature map after two 2x poolings.

        Returns:
            Tensor of shape (N, 10) with one score per class.
        """
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        # flatten(1) keeps the batch dimension intact. A wrongly sized input then
        # fails in fc1, instead of being silently folded into the batch
        # dimension the way view(-1, 64*7*7) would do.
        x = x.flatten(1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../.cache/data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:05<00:00, 4.48MB/s]


Extracting ../.cache/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../.cache/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../.cache/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 151kB/s]


Extracting ../.cache/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../.cache/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../.cache/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 2.28MB/s]


Extracting ../.cache/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../.cache/data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../.cache/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 14.9MB/s]

Extracting ../.cache/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../.cache/data/FashionMNIST/raw






In [14]:
from spikingjelly.activation_based import neuron, functional, surrogate, layer
# Initialise the network. Its `relu` activation (the one applied after fc1)
# is replaced with a spiking module further down in this cell.
model = CNN()

class SpikingModule(nn.Module):
    """Run a multi-step module over ``T`` repeated copies of a single-step input.

    The input ``x`` is repeated along a new leading time axis of length ``T``.
    The resulting ``(T, *x.shape)`` sequence is passed through ``module``, and
    the outputs are summed over time. The result has the same shape as ``x``.
    If ``module`` emits binary spikes, the result is a per-element spike
    *count* in ``[0, T]``. It is not divided by ``T``, so it is not a
    normalised firing rate.

    Args:
        T: number of time steps to simulate (must be >= 1).
        module: module applied to the whole time sequence at once.
    """

    def __init__(self, T, module):
        super().__init__()
        if T < 1:
            raise ValueError(f"T must be >= 1, got {T}")
        self.T = T
        self.module = module
        # Switch the step-mode-aware submodules to multi-step ('m') mode so
        # they consume the whole time axis in a single call.
        functional.set_step_mode(self, step_mode='m')

    def forward(self, x):
        # Use self.T (previously hard-coded to 10, which ignored the T argument).
        # expand() creates a broadcast view; no copy of x is made.
        x_seq = x.unsqueeze(0).expand(self.T, *x.shape)
        y_seq = self.module(x_seq)
        # Accumulate over the time axis.
        return y_seq.sum(0)


# Replace the activation after fc1 with spikingjelly's integrate-and-fire
# neuron. It is simulated for 10 time steps, uses an arctan surrogate
# gradient for training, and detaches the reset from the graph.
model.relu = SpikingModule(
    T=10,
    module=neuron.IFNode(
        surrogate_function=surrogate.ATan(),
        step_mode="m",
        detach_reset=True,
    ),
)


In [15]:
from itertools import cycle
from itertools import islice

# Hyperparameters
batch_size = 64
learning_rate = 0.001

# Preprocessing: ToTensor, then Normalize((0.5,), (0.5,)), i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the datasets (download=True fetches them into root on first use).
train_dataset = torchvision.datasets.FashionMNIST(root='../.cache/data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = torchvision.datasets.FashionMNIST(root='../.cache/data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = model.to(device)

# Loss function and optimiser
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def _endless_batches(loader):
    """Yield batches from `loader` forever, starting a new epoch on each pass.

    Unlike itertools.cycle, this re-iterates the DataLoader every epoch.
    With shuffle=True each epoch therefore gets a fresh order, and no batches
    are cached in memory. Assumes `loader` is non-empty.
    """
    while True:
        yield from loader


# Train for a fixed number of optimisation steps (not epochs).
# TODO: seed torch (and the model init in the previous cell) for reproducibility.
step = 0
max_step = 5000
# Switch back to training mode; the evaluation cell leaves the model in eval mode.
model.train()
# islice stops after exactly max_step batches, so no extra batch is fetched.
pbar = tqdm.tqdm(islice(_endless_batches(train_loader), max_step), total=max_step)

for images, labels in pbar:
    images = images.to(device)
    labels = labels.to(device)
    step += 1
    outputs = model(images)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Reset the stateful (spiking) modules so state does not carry over to the next batch.
    functional.reset_net(model)
    pbar.set_description(f'Loss: {loss.item():.4f}')



Loss: 0.0735: 100%|██████████| 5000/5000 [00:50<00:00, 99.27it/s] 


In [16]:
# Evaluate the network on the test set.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score (first one on ties).
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # Reset the stateful (spiking) modules between batches, as in training.
        functional.reset_net(model)

    # Report the measured number of test samples instead of a hard-coded count.
    print(f'Accuracy of the network on the {total} test images: {100 * correct / total} %')

Accuracy of the network on the 10000 test images: 89.72 %
