In [6]:
import os  
import time  
import sys  
import datetime  
import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from torch.cuda import amp  
import torchvision  
import numpy as np  
from torch.utils.tensorboard import SummaryWriter

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer  

class TransformerSNN(nn.Module):  
    def __init__(self, img_size=28, patch_size=7, in_channels=1, embed_dim=128, num_heads=4, num_classes=10, tau=2.0):  
        super(TransformerSNN, self).__init__()  
        
        self.patch_embed = nn.Sequential(  
            nn.Conv2d(in_channels, embed_dim,   
                      kernel_size=patch_size,   
                      stride=patch_size),  
            layer.BatchNorm2d(embed_dim)  
        )  
         
        num_patches = (img_size // patch_size) ** 2  
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  
        
        self.transformer_encoder = nn.TransformerEncoder(  
            nn.TransformerEncoderLayer(  
                d_model=embed_dim,   
                nhead=num_heads,  
                dim_feedforward=embed_dim*4,  
                activation=F.relu  
            ),  
            num_layers=2  
        )  
        
        # SNN 分类器  
        self.snn_classifier = nn.Sequential(  
            layer.Linear(16 * 128, 1024),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),  
            layer.Linear(1024, 256, bias=False),  
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),  
            layer.Linear(256, num_classes, bias=False),  
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())  
        )  

    def forward(self, x):  
        x = self.patch_embed(x)  
        B, C, H, W = x.shape  
        x = x.view(B, C, H*W).transpose(1, 2)  
        x += self.pos_embed  
        x = self.transformer_encoder(x)  
        x = x.view(B, -1)  
        return x 
    
def evaluate_accuracy(data_iter, net, encoder, T):  
    acc_sum, n = 0.0, 0
    net.eval()  
    with torch.no_grad():  
        for X, y in data_iter:   
            for t in range(T):
                spike_input = encoder(X)  # Encode to spikes  
                class_output, _ = net(spike_input)  
                if t == 0:
                    out_fr = torch.zeros_like(class_output) 
                out_fr += class_output
            out_fr = out_fr / T
            acc_sum += (out_fr.argmax(1) == y).float().sum().item()  
            n += y.numel()
            functional.reset_net(net)
    return acc_sum / n

def main():  
    class parser:  
        def __init__(self):  
            self.T = 10  
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  
            self.epochs = 10  
            self.b = 128  
            self.j = 4  
            self.data_dir = './mnist_data'  
            self.out_dir = './logs'  
            self.resume = None
            self.amp = True  
            self.opt = 'adam'  
            self.lr = 1e-3  
            self.tau = 2.0  

    args = parser()   

    net = TransformerSNN(tau=args.tau)  
    print(net)  
    net.to(args.device)  

    # 数据加载 MNIST 
    train_dataset = torchvision.datasets.MNIST(  
        root=args.data_dir,  
        train=True,  
        transform=torchvision.transforms.ToTensor(),  
        download=True  
    )  
    test_dataset = torchvision.datasets.MNIST(  
        root=args.data_dir,  
        train=False,  
        transform=torchvision.transforms.ToTensor(),  
        download=True  
    )  

    # 数据加载 FashionMNIST
    # train_dataset = torchvision.datasets.FashionMNIST(
    #         root=args.data_dir,
    #         train=True,
    #         transform=torchvision.transforms.ToTensor(),
    #         download=True)

    # test_dataset = torchvision.datasets.FashionMNIST(
    #         root=args.data_dir,
    #         train=False,
    #         transform=torchvision.transforms.ToTensor(),
    #         download=True)

    train_loader = torch.utils.data.DataLoader(  
        dataset=train_dataset,  
        batch_size=args.b,  
        shuffle=True,  
        drop_last=True,  
        num_workers=args.j,  
        pin_memory=True  
    )  
    test_loader = torch.utils.data.DataLoader(  
        dataset=test_dataset,  
        batch_size=args.b,  
        shuffle=False,  
        drop_last=False,  
        num_workers=args.j,  
        pin_memory=True  
    )  


    # 混合精度和优化器  
    scaler = amp.GradScaler() if args.amp else None  
    
    optimizer = (  
        torch.optim.Adam(net.parameters(), lr=args.lr)   
        if args.opt == 'adam'   
        else torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9)  
    )  

    start_epoch = 0  
    if args.resume is not None:  
        checkpoint = torch.load(args.resume)  
        net.load_state_dict(checkpoint['model_state_dict'])  
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  
        start_epoch = checkpoint['epoch'] + 1

    # Poisson 编码器  
    encoder = encoding.PoissonEncoder()  

    # 训练日志目录  
    out_dir = os.path.join(  
        args.out_dir,   
        f'TransformerSNN_T{args.T}_MNIST'  
    )  
    os.makedirs(out_dir, exist_ok=True)  

    writer = SummaryWriter(log_dir=out_dir)

    max_test_acc = -1  

    for epoch in range(start_epoch, args.epochs):  
        net.train()  
        train_loss, train_acc, train_samples = 0, 0, 0  

        for img, label in train_loader:  
            optimizer.zero_grad()  
            
            img = img.to(args.device)  
            label = label.to(args.device)  
            label_onehot = F.one_hot(label, 10).float()

            out_fr = 0.  
            for t in range(args.T):  
                # Poisson 编码  
                encoded_img = encoder(img)  
                # 提取特征  
                features = net(encoded_img)  
                # 脉冲神经元处理  
                out_fr += net.snn_classifier(features)  
                
            out_fr = out_fr / args.T  
            loss = F.mse_loss(out_fr, label_onehot)  

            # 反向传播  
            loss.backward()  
            optimizer.step()  

            # 统计指标  
            train_samples += label.numel()  
            train_loss += loss.item() * label.numel()  
            train_acc += (out_fr.argmax(1) == label).float().sum().item()  

            # 重置网络状态  
            functional.reset_net(net)  

        # 验证  
        net.eval()  
        test_loss, test_acc, test_samples = 0, 0, 0  
        with torch.no_grad():  
            for img, label in test_loader:  
                img = img.to(args.device)  
                label = label.to(args.device)
                if label.dtype != torch.long:  
                    label = label.long()  
                label_onehot = F.one_hot(label, 10).float()  

                out_fr = 0.  
                for t in range(args.T):  
                    encoded_img = encoder(img)  
                    features = net(encoded_img)  
                    out_fr += net.snn_classifier(features)  
                
                out_fr = out_fr / args.T  
                loss = F.mse_loss(out_fr, label_onehot)  

                test_samples += label.numel()  
                test_loss += loss.item() * label.numel()  
                test_acc += (out_fr.argmax(1) == label).float().sum().item()  

                functional.reset_net(net)  

        # 输出训练信息  
        train_loss /= train_samples  
        train_acc /= train_samples  
        test_loss /= test_samples  
        test_acc /= test_samples  

        writer.add_scalar('Loss/train', train_loss, epoch)  
        writer.add_scalar('Loss/test', test_loss, epoch)  
        writer.add_scalar('Accuracy/train', train_acc, epoch)  
        writer.add_scalar('Accuracy/test', test_acc, epoch)

        # 保存最佳模型  
        if test_acc > max_test_acc:  
            max_test_acc = test_acc  
            torch.save(net.state_dict(), os.path.join(out_dir, 'best_model.pth'))  

        print(f'Epoch {epoch}, '  
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '  
              f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, '  
              f'Best Test Acc: {max_test_acc:.4f}')  

    # 保存模型  
    torch.save(net.state_dict(), os.path.join(out_dir, 'final_model.pth'))  
    writer.close()

    # tensorboard --logdir=./logs
    # MNIST:131m55.6s
    # FashionMNIST:121m9.4s
 
if __name__ == '__main__':  
    main()

TransformerSNN(
  (patch_embed): Sequential(
    (0): Conv2d(1, 128, kernel_size=(7, 7), stride=(7, 7))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (snn_classifier): Sequential(
    (0): Linear(in_features=2048, out_fe

  scaler = amp.GradScaler() if args.amp else None


Epoch 0, Train Loss: 0.0121, Train Acc: 0.9231, Test Loss: 0.0045, Test Acc: 0.9707, Best Test Acc: 0.9707
Epoch 1, Train Loss: 0.0039, Train Acc: 0.9764, Test Loss: 0.0031, Test Acc: 0.9804, Best Test Acc: 0.9804
Epoch 2, Train Loss: 0.0027, Train Acc: 0.9844, Test Loss: 0.0028, Test Acc: 0.9814, Best Test Acc: 0.9814
Epoch 3, Train Loss: 0.0021, Train Acc: 0.9875, Test Loss: 0.0024, Test Acc: 0.9854, Best Test Acc: 0.9854
Epoch 4, Train Loss: 0.0017, Train Acc: 0.9903, Test Loss: 0.0025, Test Acc: 0.9838, Best Test Acc: 0.9854
Epoch 5, Train Loss: 0.0015, Train Acc: 0.9916, Test Loss: 0.0024, Test Acc: 0.9840, Best Test Acc: 0.9854
Epoch 6, Train Loss: 0.0012, Train Acc: 0.9934, Test Loss: 0.0021, Test Acc: 0.9867, Best Test Acc: 0.9867
Epoch 7, Train Loss: 0.0012, Train Acc: 0.9930, Test Loss: 0.0024, Test Acc: 0.9848, Best Test Acc: 0.9867
Epoch 8, Train Loss: 0.0010, Train Acc: 0.9944, Test Loss: 0.0019, Test Acc: 0.9881, Best Test Acc: 0.9881
Epoch 9, Train Loss: 0.0008, Train Ac