In [1]:
import torch
import torch.nn.functional as F
train1 = torch.load(r'E:\snn_final\ColoredMNIST\train1.pt',weights_only=False)
train2 = torch.load(r'E:\snn_final\ColoredMNIST\train2.pt',weights_only=False)
test = torch.load(r'E:\snn_final\ColoredMNIST\test.pt',weights_only=False)

from torchvision import transforms
import numpy as np

def dataset_load(raw_dataset):
    x=[]
    y=[]
    for one_data in raw_dataset:
        y.append(one_data[1])
        image=one_data[0]
        data=np.array(image)
        data=data.transpose(2,0,1)
        x.append(data)
    x=torch.Tensor(np.array(x))
    y=torch.Tensor(np.array(y))
    ds=torch.utils.data.TensorDataset(x,y)
    return ds

train1_data = dataset_load(train1)
train2_data = dataset_load(train2)
test_data = dataset_load(test)

In [None]:
import torch  
import torch.nn as nn  
import numpy as np  
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer  
from torch.utils.tensorboard import SummaryWriter
import os

class ReversalLayer(torch.autograd.Function):  
    @staticmethod  
    def forward(ctx, x, alpha):  
        ctx.alpha = alpha  
        return x.view_as(x)  

    @staticmethod  
    def backward(ctx, grad_output):  
        output = grad_output.neg() * ctx.alpha  
        return output, None  

class SpikingDANN_mamba(nn.Module):  
    def __init__(self):  
        super(SpikingDANN_mamba, self).__init__()  
        
        # 主路径是一个基础的卷积网络，提取初始特征
        self.main_branch = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),  
            nn.MaxPool2d(2, 2),  
            nn.ReLU(inplace=True),
        )

        # 分支路径有不同尺度的卷积核，用于提取多尺度特征
        self.branch1 = nn.Sequential(  # 分支1，卷积核大小为3x3  
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  
            nn.BatchNorm2d(32),  
            nn.MaxPool2d(2, 2),  
            nn.ReLU(inplace=True),  
        )  

        self.branch2 = nn.Sequential(  # 分支2，卷积核大小为5x5  
            nn.Conv2d(16, 32, kernel_size=5, padding=2),  
            nn.BatchNorm2d(32),  
            nn.MaxPool2d(2, 2),  
            nn.ReLU(inplace=True),  
        )  

        self.branch3 = nn.Sequential(  # 分支3，卷积核大小为7x7  
            nn.Conv2d(16, 32, kernel_size=7, padding=3),  
            nn.BatchNorm2d(32),  
            nn.MaxPool2d(2, 2),  
            nn.ReLU(inplace=True),  
        )  

        # 通过线性层和softmax实现动态加权融合，根据输入的统计信息为每个分支赋权
        self.fusion_weights = nn.Sequential(  
            nn.Linear(3, 3),  # 分支数为3  
            nn.Softmax(dim=1)  # 动态加权融合分数  
        ) 

        self.feature_fusion = nn.Conv2d(32 * 3, 16, kernel_size=1)  # 将3个分支拼接  
        
        # SNN Classifier with Leaky Integrate-and-Fire neurons  
        self.classifier = nn.Sequential(  
            nn.Linear(784, 32),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(32, 32),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(32, 2),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
        )  
        
        # Domain Discriminator  
        self.domain_discriminator = nn.Sequential(  
            nn.Linear(784, 16),  
            nn.Dropout(0.3),
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(16, 2), 
        )  

    def forward(self, input_spikes, alpha=1.0):  
        main_features = self.main_branch(input_spikes)
        branch1_features = self.branch1(main_features)
        branch2_features = self.branch2(main_features)
        branch3_features = self.branch3(main_features)
        branch_outputs = torch.stack([
            branch1_features.mean(dim=(2, 3)),  # 平均池化 -> (B, C)  
            branch2_features.mean(dim=(2, 3)),   
            branch3_features.mean(dim=(2, 3)) 
        ], dim=-1)

        fusion_weights = self.fusion_weights(branch_outputs.mean(dim=1))
        weighted_branch1 = fusion_weights[:, 0].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * branch1_features  
        weighted_branch2 = fusion_weights[:, 1].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * branch2_features  
        weighted_branch3 = fusion_weights[:, 2].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * branch3_features  
        fused_features = torch.cat([weighted_branch1, weighted_branch2, weighted_branch3], dim=1)  

        features = self.feature_fusion(fused_features)  
        flattened_features = features.view(input_spikes.shape[0], -1)

        class_output = self.classifier(flattened_features)  
        domain_output = self.domain_discriminator(ReversalLayer.apply(flattened_features, alpha))  
        
        return class_output, domain_output  
      
def evaluate_accuracy(data_iter, net, encoder, T):  
    acc_sum, n = 0.0, 0
    net.eval()  
    with torch.no_grad():  
        for X, y in data_iter:   
            for t in range(T):
                spike_input = encoder(X)  # Encode to spikes  
                class_output, _ = net(spike_input)  
                if t == 0:
                    out_fr = torch.zeros_like(class_output) 
                out_fr += class_output
            out_fr = out_fr / T
            acc_sum += (out_fr.argmax(1) == y).float().sum().item()  
            n += y.numel()
            functional.reset_net(net)
    return acc_sum / n 
    

def train_snn(net, train1_data, train2_data, test_data):  
    class parser:  
        def __init__(self):  
            self.T = 5  
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  
            self.epochs = 10  
            self.b = 128  
            self.j = 4  
            self.out_dir = './logs'  
            self.resume = None
            self.amp = True  
            self.opt = 'adam'  
            self.lr = 1e-3  
            self.tau = 2.0  
            self.alpha_domain = 1.0
    
    args = parser()
    
    encoder = encoding.PoissonEncoder()
    loss_class_func = nn.CrossEntropyLoss()
    loss_domain_func = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(net.parameters(),lr=args.lr)

    out_dir = os.path.join(  
        args.out_dir,   
        f'Mamba_DANN_SNN_T{args.T}'  
    )  
    os.makedirs(out_dir, exist_ok=True)  

    writer = SummaryWriter(log_dir=out_dir)
    
    for epoch in range(args.epochs):  
        net.train()
        train1_acc, train2_acc, class_loss, domain1_loss, domain2_loss = 0,0,0,0,0 
        domain1_acc, domain2_acc = 0,0
        train_samples = 0
        p = float(epoch)/float(args.epochs)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1  
        for (source_img, source_class), (target_img, target_class) in zip(train1_data, train2_data):  
            optimizer.zero_grad()  
            source_spikes = encoder(source_img) # 128,3,28,28
            target_spikes = encoder(target_img)   
            source_domain = torch.zeros(source_img.shape[0],dtype=torch.long).to(source_img.device)  
            target_domain = torch.ones(target_img.shape[0],dtype=torch.long).to(target_img.device)  
            source_class = source_class.long()
            target_class = target_class.long()
 
            for t in range(args.T):
                class_output_source, domain_output_source = net(source_spikes,alpha=alpha)  
                class_output_target, domain_output_target = net(target_spikes,alpha=alpha)  
                
                if t == 0:
                    out_fr_class_s = torch.zeros_like(class_output_source)  
                    out_fr_domain_s = torch.zeros_like(domain_output_source)  
                    out_fr_class_t = torch.zeros_like(class_output_target)  
                    out_fr_domain_t = torch.zeros_like(domain_output_target) 
                
                out_fr_class_s += class_output_source
                out_fr_domain_s += domain_output_source
                out_fr_class_t += class_output_target
                out_fr_domain_t += domain_output_target
            out_fr_class_s = out_fr_class_s / args.T
            out_fr_domain_s = out_fr_domain_s / args.T
            out_fr_class_t = out_fr_class_t / args.T
            out_fr_domain_t = out_fr_domain_t / args.T

            # Calculate losses  
            loss_class = loss_class_func(out_fr_class_s, source_class)  
            loss_domain_source = loss_domain_func(out_fr_domain_s, source_domain)  
            loss_domain_target = loss_domain_func(out_fr_domain_t, target_domain)   
            loss = loss_class + loss_domain_source + loss_domain_target  
            
            loss.backward()  # Backpropagation  
            optimizer.step()  # Optimization  

            train_samples += source_class.size(0)
            class_loss += loss_class.item() * source_class.size(0)
            domain1_loss += loss_domain_source.item() * source_class.size(0)
            domain2_loss += loss_domain_target.item() * source_class.size(0)
            train1_acc += (out_fr_class_s.argmax(1) == source_class).float().sum().item()
            train2_acc += (out_fr_class_t.argmax(1) == target_class).float().sum().item()
            domain1_acc += (out_fr_domain_s.argmax(1) == source_domain).float().sum().item()
            domain2_acc += (out_fr_domain_t.argmax(1) == target_domain).float().sum().item()

            functional.reset_net(net)

        class_loss /= train_samples
        domain1_loss /= train_samples
        domain2_loss /= train_samples
        train1_acc /= train_samples
        train2_acc /= train_samples
        domain1_acc /= train_samples
        domain2_acc /= train_samples

        acc = evaluate_accuracy(test_data, net, encoder, args.T)

        writer.add_scalar('Class Acc on source', train1_acc, epoch)
        writer.add_scalar('Class Acc on target', train2_acc, epoch)
        writer.add_scalar('Domain Acc on source', domain1_acc, epoch)
        writer.add_scalar('Domain Acc on target', domain2_acc, epoch)
        writer.add_scalar('Class Loss', class_loss, epoch)
        writer.add_scalar('Domain Loss on source', domain1_loss, epoch)
        writer.add_scalar('Domain Loss on target', domain2_loss, epoch)
        writer.add_scalar('Class Acc on test', acc, epoch)
        
        print(f'Epoch {epoch + 1}, Class Loss: {class_loss:.4f}, Domain Source Loss: {domain1_loss:.4f}, Domain Target Loss: {domain2_loss:.4f}')   
        print(f'Class Acc on train_source: {train1_acc:.4f}, Domain Acc on train_source: {domain1_acc:.4f}')  
        print(f'Class Acc on train_target: {train2_acc:.4f}, Domain Acc on train_target: {domain2_acc:.4f}')  
        print(f'Class Acc on test: {acc:.4f}')  


In [9]:
net = SpikingDANN_mamba() 
  
train1_loader = torch.utils.data.DataLoader(train1_data, batch_size=128, shuffle=True)
train2_loader = torch.utils.data.DataLoader(train2_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True)
 
train_snn(net, train1_loader, train2_loader, test_loader)  

Epoch 1, Class Loss: 0.6467, Domain Source Loss: 0.6799, Domain Target Loss: 0.6138
Class Acc on train_source: 0.5393, Domain Acc on train_source: 0.5115
Class Acc on train_target: 0.6025, Domain Acc on train_target: 0.7668
Class Acc on test: 0.0997
Epoch 2, Class Loss: 0.5149, Domain Source Loss: 0.6627, Domain Target Loss: 0.6425
Class Acc on train_source: 0.7981, Domain Acc on train_source: 0.6270
Class Acc on train_target: 0.8972, Domain Acc on train_target: 0.6152
Class Acc on test: 0.0998
Epoch 3, Class Loss: 0.5166, Domain Source Loss: 0.6797, Domain Target Loss: 0.6803
Class Acc on train_source: 0.7950, Domain Acc on train_source: 0.5864
Class Acc on train_target: 0.8895, Domain Acc on train_target: 0.5575
Class Acc on test: 0.0998
Epoch 4, Class Loss: 0.5167, Domain Source Loss: 0.6986, Domain Target Loss: 0.6867
Class Acc on train_source: 0.7945, Domain Acc on train_source: 0.5050
Class Acc on train_target: 0.8893, Domain Acc on train_target: 0.5610
Class Acc on test: 0.1124
