In [14]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression in a cell, not just the last one.
# Side effect: calls such as `net.train()` (which return the module) would echo
# their result, presumably why later cells assign those results to `_`.
InteractiveShell.ast_node_interactivity = "all"
# To restore IPython's default (display only the last expression) use "last_expr".

# 训练 

In [15]:
from spikingjelly.activation_based import surrogate
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm  # tqdm 进度条显示
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import sys
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Display the device selected above.
device

device(type='cuda')

In [3]:
class SNN(nn.Module):
    """Single-layer spiking classifier for 28x28 images.

    Flattens the input, maps the 784 pixels to 10 output units with a
    bias-free linear layer, and feeds them to LIF neurons trained through an
    arctangent (ATan) surrogate gradient. One ``forward`` call is one time
    step; the training code loops over T steps and resets neuron state
    between batches.

    Args:
        tau: membrane time constant forwarded to ``neuron.LIFNode``.
    """

    def __init__(self, tau=2.0):
        super().__init__()
        # Build the stages first, then chain them. The attribute name and the
        # stage order are unchanged so saved state_dict keys stay compatible.
        flatten = layer.Flatten()
        fc = layer.Linear(28 * 28, 10, bias=False)
        lif = neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
        self.layer = nn.Sequential(flatten, fc, lif)

    def forward(self, x: torch.Tensor):
        """Advance the network by one time step and return the LIF layer output."""
        return self.layer(x)

In [4]:
# ---- Hyper-parameters ----
bs = 512          # batch size
train_epoch = 5   # number of training epochs
T = 5             # simulation time steps per sample

# Seed PyTorch's RNG so weight init, data shuffling and Poisson encoding repeat
# across runs (some CUDA kernels may still be nondeterministic).
_ = torch.manual_seed(0)

net = SNN().to(device)

# MNIST dataset (download=False: the files must already exist under ../data)
train_set = datasets.MNIST("../data", train=True, download=False, transform=transforms.ToTensor())
test_set = datasets.MNIST("../data", train=False, download=False, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=True, drop_last=True)
# NOTE(review): drop_last=True on the test split means the final partial batch
# (10000 % 512 = 272 images) is never evaluated.
test_loader = DataLoader(dataset=test_set, batch_size=bs, shuffle=False, drop_last=True)

encoder = encoding.PoissonEncoder()
optimizer = optim.Adam(net.parameters(), lr=0.01)
loss_function = F.mse_loss

# Metric history, one entry appended per epoch by the training loop
best_acc = 0
train_accs = []
train_loss = []
test_accs = []
test_loss = []

In [None]:
for epoch in range(train_epoch):
    # ---------------- Training phase ----------------
    _ = net.train()
    # The neurons are stateful; start the epoch from a clean membrane state in
    # case an earlier (possibly interrupted) cell left some behind.
    functional.reset_net(net)
    sums = 0
    accuracys = 0
    los = 0
    for img, label in tqdm(train_loader):
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()
        # Average the output over T freshly Poisson-encoded time steps (firing rate).
        out_fr = 0.
        for t in range(T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_fr = out_fr / T
        # Loss between firing rates and one-hot targets
        loss = loss_function(out_fr, label_onehot)
        loss.backward()
        optimizer.step()
        # Sample-weighted running statistics for this epoch
        sums += label.numel()
        los += loss.item() * label.numel()
        accuracys += (out_fr.argmax(1) == label).float().sum().item()

        # Required: after each optimizer step, reset the network state, because
        # SNN neurons carry "memory" from one call to the next.
        functional.reset_net(net)

    # Overall accuracy (%) and mean loss for this training epoch
    temp = round(accuracys / sums * 100, 2)
    temp2 = round(los / sums, 4)
    print(f"第{epoch+1}次迭代的整体准确率为{temp}%, 平均损失为{temp2}")
    train_accs.append(temp)
    train_loss.append(temp2)

    # ---------------- Evaluation phase ----------------
    _ = net.eval()
    with torch.no_grad():
        sums = 0
        accuracys = 0
        los = 0
        for img, label in tqdm(test_loader):
            img = img.to(device)
            label = label.to(device)
            label_onehot = F.one_hot(label, 10).float()
            # Bug fix: start a fresh accumulator for every batch. Previously
            # out_fr continued from the last training batch's output and kept
            # growing across test batches.
            out_fr = 0.
            for t in range(T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / T
            loss = loss_function(out_fr, label_onehot)
            sums += label.numel()
            los += loss.item() * label.numel()
            accuracys += (out_fr.argmax(1) == label).float().sum().item()
            # Bug fix: reset neuron state between test batches, exactly as in training.
            functional.reset_net(net)
        temp = round(accuracys / sums * 100, 2)
        temp2 = round(los / sums, 4)
        print(f"网络在测试上的整体准确率为{temp}%, 平均损失为{temp2}")
        test_accs.append(temp)
        test_loss.append(temp2)

    # Keep the checkpoint with the best test accuracy seen so far
    if best_acc < temp:
        torch.save(net.state_dict(), 'save_models/snn.pth')
        best_acc = temp
        print("update weight")

# 推理 

In [None]:
from spikingjelly.activation_based import surrogate
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm  # tqdm 进度条显示
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import sys
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer
# The inference section below is pinned to the CPU.
device = torch.device("cpu")

In [None]:
class SNN(nn.Module):
    """Spiking classifier re-declared so this inference section can rebuild
    the model before loading saved weights.

    Layout: Flatten -> Linear(784 -> 10, no bias) -> LIF neurons with an
    ATan surrogate gradient. ``tau`` is forwarded to ``neuron.LIFNode``.
    """

    def __init__(self, tau=2.0):
        super().__init__()
        stages = [
            layer.Flatten(),
            layer.Linear(28 * 28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        ]
        # Positional Sequential -> state_dict keys "layer.<index>.*", matching saved checkpoints.
        self.layer = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        # One simulation time step.
        return self.layer(x)

In [None]:
# Hyper-parameters for this section. train_epoch and the metric containers at
# the bottom are not referenced by the inference cells that follow.
bs = 512
train_epoch = 20
T = 5

# MNIST dataset (download=False: the files must already exist under ../data)
train_set = datasets.MNIST(root="../data", train=True, download=False, transform=transforms.ToTensor())
test_set = datasets.MNIST(root="../data", train=False, download=False, transform=transforms.ToTensor())

train_loader = DataLoader(train_set, batch_size=bs, shuffle=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size=bs, shuffle=False, drop_last=True)

encoder = encoding.PoissonEncoder()
loss_function = F.mse_loss

best_acc = 0
train_accs, train_loss, test_accs, test_loss = [], [], [], []

In [None]:
# Rebuild the architecture and load the weights written by the training section.
test_net = SNN().to(device)
# map_location: the checkpoint may have been saved from a CUDA device, while this
# section runs on `device` (CPU). Without it, torch.load tries to restore tensors
# onto their original device and fails on machines without CUDA.
test_net.load_state_dict(torch.load('save_models/snn.pth', map_location=device))

In [None]:
# Evaluate the reloaded network on the test split.
# NOTE(review): `evaluates` is only defined in the mixed-precision section further
# down; on a fresh kernel this line raises NameError unless that definition cell
# has been run first.
acc, los = evaluates("test", test_net, test_loader, T)

# 使用混合精度进行训练

In [12]:
from spikingjelly.activation_based import surrogate
import torch
from torch import amp
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.notebook import tqdm  # tqdm 进度条显示
import torch.nn.functional as F
import torch.optim as optim
import time
import sys
import matplotlib.pyplot as plt
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer
# Use the GPU if PyTorch can see one; otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
# Mixed-precision run: time steps, batch size, Adam learning rate, epoch count.
T, batch_size, lr, num_epochs = 20, 512, 1e-3, 20

class SNN(nn.Module):
    """Spiking MNIST classifier used by the mixed-precision run.

    Args:
        tau: LIF membrane time constant, passed to ``neuron.LIFNode`` (default 2.0).

    ``forward`` processes one time step; membrane state persists between calls
    until ``functional.reset_net`` is applied to the network.
    """

    def __init__(self, tau=2.0):
        super().__init__()
        n_inputs, n_classes = 28 * 28, 10
        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(n_inputs, n_classes, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

    def forward(self, x: torch.Tensor):
        """Single time-step forward pass through flatten -> linear -> LIF."""
        return self.layer(x)

# Select the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fresh network on that device, an Adam optimizer over its weights, and a
# Poisson encoder that turns pixel intensities into spike trains.
net = SNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
encoder = encoding.PoissonEncoder()

In [8]:
data_path = '../data'

# Preprocessing. Resize((28, 28)) is a no-op for 28x28 MNIST images and
# Normalize((0,), (1,)) is the identity, so pixel values stay in the [0, 1] range
# produced by ToTensor, which the Poisson encoder expects as spike probabilities.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

mnist_train = datasets.MNIST(data_path, train=True, download=False, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=False, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluation needs no shuffling; a fixed order makes test runs comparable.
# NOTE(review): drop_last=True skips the last partial test batch (10000 % batch_size images).
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=True)

In [None]:
def evaluates(name, model, data_loader, T):
    """Evaluate a spiking classifier on ``data_loader``; print and return its metrics.

    Each batch is Poisson-encoded ``T`` times with the module-level ``encoder``.
    The model outputs are averaged into firing rates, and the predicted class is
    the output neuron with the highest rate.

    Args:
        name: label prefixed to the printed summary lines.
        model: network to evaluate. It is switched to eval mode, and its neuron
            state is reset before the first batch and after every batch.
        data_loader: iterable of ``(data, targets)`` batches with integer class labels.
        T: number of simulation time steps per batch.

    Returns:
        ``(accuracy, loss)``: accuracy as a fraction in [0, 1] and the
        sample-weighted mean MSE between firing rates and one-hot targets.

    Raises:
        ValueError: if ``data_loader`` yields no samples.

    Note:
        Reads the module-level ``device`` and ``encoder`` globals.
    """
    acc = 0
    losss = 0
    data_num = 0
    _ = model.eval()
    # Clear any membrane state left over from training or from an earlier call.
    functional.reset_net(model)

    # Iterate the loader itself (not iter(loader)) so tqdm knows the total length.
    datas = tqdm(data_loader, file=sys.stdout)
    with torch.no_grad():  # evaluation needs no autograd graph
        for data, targets in datas:
            data = data.to(device)
            targets = targets.to(device)
            # Bug fix: one-hot encode THIS batch's targets. The previous code used
            # an undefined `label`, which silently picked up a stale global.
            label_onehot = F.one_hot(targets, 10).float()

            out_fr = 0.
            for t in range(T):
                encoded_img = encoder(data)
                out_fr += model(encoded_img)  # sum the output spikes of every step
            out_fr = out_fr / T  # mean over time steps -> firing rate

            loss = F.mse_loss(out_fr, label_onehot)
            data_num += targets.numel()
            losss += loss.item() * targets.numel()
            # Predicted class = index of the output neuron with the highest firing rate
            acc += (out_fr.argmax(1) == targets).float().sum().item()
            # Neurons are stateful: reset so the next batch does not inherit this one's state.
            functional.reset_net(model)

    if data_num == 0:
        raise ValueError(f"{name}: data_loader yielded no samples")

    print(f"{name} All Acc: {acc/data_num * 100:.2f}%")
    print(f"{name} All Loss: {losss/data_num:.2f}")
    return acc / data_num, losss / data_num

In [13]:
# Gradient scaling for mixed precision only applies on CUDA. On a CPU-only
# machine, GradScaler('cuda') merely warns and disables itself, so fall back to
# plain fp32 training there (scaler = None selects the non-AMP branch of the loop).
scaler = amp.GradScaler("cuda") if device.type == "cuda" else None
train_accs = []
train_losss = []
test_accs = []
test_losss = []
best_acc = 0

In [None]:
for epoch in range(num_epochs):
    # Wall-clock start of this epoch (reported as train_single_time below)
    start_time = time.time()
    _ = net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0

    # Iterate the loader itself (not iter(loader)) so tqdm can show a total.
    datas = tqdm(train_loader, file=sys.stdout)
    for img, label in datas:
        optimizer.zero_grad()
        img = img.to(device)
        label = label.to(device)
        label_onehot = F.one_hot(label, 10).float()  # one-hot targets for the MSE loss

        if scaler is not None:
            # Mixed-precision forward pass. torch.amp.autocast requires an explicit
            # device_type; the bare amp.autocast() call raised TypeError.
            with amp.autocast(device_type=device.type):
                out_fr = 0.
                # Run T time steps, each on a fresh Poisson encoding of the batch
                for t in range(T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)  # sum the output spikes of every step
                out_fr = out_fr / T  # mean -> firing rate, shape [batch_size, 10]
                # MSE between firing rates and one-hot labels drives the true class's
                # rate toward 1 and every other rate toward 0.
                loss = F.mse_loss(out_fr, label_onehot)
            # Scale the loss against reduced-precision gradient underflow, then
            # step the optimizer and update the scale factor.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Plain full-precision path
            out_fr = 0.
            for t in range(T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / T
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        # Predicted class = index of the output neuron with the highest firing rate
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Neurons are stateful ("memory"): reset after every optimizer step.
        functional.reset_net(net)

    end_time = time.time()
    print(f"epoch = {epoch}")
    print(f"train_single_time = {end_time - start_time:.4f}")
    print(f"loss = {train_loss/train_samples:.4f}")
    print(f"acc = {train_acc/train_samples*100:.2f}%")
    train_accs.append(train_acc / train_samples)
    train_losss.append(train_loss / train_samples)

    test_acc, test_loss = evaluates("test", net, test_loader, T)
    test_accs.append(test_acc)
    test_losss.append(test_loss)

    # Keep the checkpoint with the best test accuracy seen so far
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(net.state_dict(), 'save_models/snn2.pt')
        print("save model param")

    print("\n")

# 推理 

In [None]:
# Reload the best mixed-precision checkpoint and evaluate it on the test split.
test_net2 = SNN().to(device)
# Bug fix: the mixed-precision training loop saves to 'save_models/snn2.pt',
# not '.pth'. map_location lets a checkpoint written on another device load here.
test_net2.load_state_dict(torch.load('save_models/snn2.pt', map_location=device))
acc, los = evaluates("test2", test_net2, test_loader, T)