In [1]:
import torch
import torch.nn.functional as F
train1 = torch.load(r'E:\snn_final\ColoredMNIST\train1.pt',weights_only=False)
train2 = torch.load(r'E:\snn_final\ColoredMNIST\train2.pt',weights_only=False)
test = torch.load(r'E:\snn_final\ColoredMNIST\test.pt',weights_only=False)

from torchvision import transforms
import numpy as np

def dataset_load(raw_dataset):
    x=[]
    y=[]
    for one_data in raw_dataset:
        y.append(one_data[1])
        image=one_data[0]
        data=np.array(image)
        data=data.transpose(2,0,1)
        x.append(data)
    x=torch.Tensor(np.array(x))
    y=torch.Tensor(np.array(y))
    ds=torch.utils.data.TensorDataset(x,y)
    return ds

train1_data = dataset_load(train1)
train2_data = dataset_load(train2)
test_data = dataset_load(test)

In [None]:
import torch  
import torch.nn as nn  
import numpy as np  
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer  
from torch.utils.tensorboard import SummaryWriter
import os

class ReversalLayer(torch.autograd.Function):  
    @staticmethod  
    def forward(ctx, x, alpha):  
        ctx.alpha = alpha  
        return x.view_as(x)  

    @staticmethod  
    def backward(ctx, grad_output):  
        output = grad_output.neg() * ctx.alpha  
        return output, None  

class SpikingDANN(nn.Module):  
    def __init__(self, img_size=28, patch_size=7, in_channels=3, embed_dim=128, num_heads=4, num_classes=2, tau=2.0):  
        super(SpikingDANN, self).__init__()  
         
        self.patch_embed = nn.Sequential(  
            nn.Conv2d(in_channels, embed_dim,   
                      kernel_size=patch_size,   
                      stride=patch_size),  
            layer.BatchNorm2d(embed_dim)  
        )  
        
        # 位置编码  
        num_patches = (img_size // patch_size) ** 2  
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  
        
        # Transformer 编码器  
        self.transformer_encoder = nn.TransformerEncoder(  
            nn.TransformerEncoderLayer(  
                d_model=embed_dim,   
                nhead=num_heads,  
                dim_feedforward=embed_dim*4,  
                activation=F.relu  
            ),  
            num_layers=2  
        )  
        
        # SNN Classifier with Leaky Integrate-and-Fire neurons  
        self.classifier = nn.Sequential(  
            nn.Linear(16 * 128, 32),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(32, 32),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(32, num_classes),  
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
        )  
        
        # Domain Discriminator  
        self.domain_discriminator = nn.Sequential(  
            nn.Linear(16 * 128, 16),  
            nn.Dropout(0.3),
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan()),  
            nn.Linear(16, 2), 
        )  

    def forward(self, input_spikes, alpha=1.0):  
        x = self.patch_embed(input_spikes)
        B, C, H, W = x.shape  
        x = x.view(B, C, H*W).transpose(1, 2)   
        x += self.pos_embed  
        x = self.transformer_encoder(x)  
        feature = x.view(B, -1)  

        class_output = self.classifier(feature)  
        domain_output = self.domain_discriminator(ReversalLayer.apply(feature, alpha))  
        
        return class_output, domain_output  
      
def evaluate_accuracy(data_iter, net, encoder, T):  
    acc_sum, n = 0.0, 0
    net.eval()  
    with torch.no_grad():  
        for X, y in data_iter:   
            for t in range(T):
                spike_input = encoder(X)  # Encode to spikes  
                class_output, _ = net(spike_input)  
                if t == 0:
                    out_fr = torch.zeros_like(class_output) 
                out_fr += class_output
            out_fr = out_fr / T
            acc_sum += (out_fr.argmax(1) == y).float().sum().item()  
            n += y.numel()
            functional.reset_net(net)
    return acc_sum / n 
    

def train_snn(net, train1_data, train2_data, test_data):  
    class parser:  
        def __init__(self):  
            self.T = 1  
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  
            self.epochs = 10  
            self.b = 128  
            self.j = 4  
            self.out_dir = './logs'  
            self.resume = None
            self.amp = True  
            self.opt = 'adam'  
            self.lr = 1e-3  
            self.tau = 2.0  
            self.alpha_domain = 1.0
    
    args = parser()
    
    encoder = encoding.PoissonEncoder()
    loss_class_func = nn.CrossEntropyLoss()
    loss_domain_func = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(net.parameters(),lr=args.lr)

    # 训练日志目录  
    out_dir = os.path.join(  
        args.out_dir,   
        f'Mamba_DANN_SNN_T{args.T}'  
    )  
    os.makedirs(out_dir, exist_ok=True)  

    writer = SummaryWriter(log_dir=out_dir)
    
    for epoch in range(args.epochs):  
        net.train()
        train1_acc, train2_acc, class_loss, domain1_loss, domain2_loss = 0,0,0,0,0 
        domain1_acc, domain2_acc = 0,0
        train_samples = 0
        p = float(epoch)/float(args.epochs)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1  
        for (source_img, source_class), (target_img, target_class) in zip(train1_data, train2_data):  
            optimizer.zero_grad()  
            source_spikes = encoder(source_img) # 128,3,28,28
            target_spikes = encoder(target_img)   
            source_domain = torch.zeros(source_img.shape[0],dtype=torch.long).to(source_img.device)  
            target_domain = torch.ones(target_img.shape[0],dtype=torch.long).to(target_img.device)  
            source_class = source_class.long()
            target_class = target_class.long()
 
            for t in range(args.T):
                class_output_source, domain_output_source = net(source_spikes,alpha=alpha)  
                class_output_target, domain_output_target = net(target_spikes,alpha=alpha)  
                
                if t == 0:
                    out_fr_class_s = torch.zeros_like(class_output_source)  
                    out_fr_domain_s = torch.zeros_like(domain_output_source)  
                    out_fr_class_t = torch.zeros_like(class_output_target)  
                    out_fr_domain_t = torch.zeros_like(domain_output_target) 
                
                out_fr_class_s += class_output_source
                out_fr_domain_s += domain_output_source
                out_fr_class_t += class_output_target
                out_fr_domain_t += domain_output_target
            out_fr_class_s = out_fr_class_s / args.T
            out_fr_domain_s = out_fr_domain_s / args.T
            out_fr_class_t = out_fr_class_t / args.T
            out_fr_domain_t = out_fr_domain_t / args.T

            # Calculate losses  
            loss_class = loss_class_func(out_fr_class_s, source_class)  
            loss_domain_source = loss_domain_func(out_fr_domain_s, source_domain)  
            loss_domain_target = loss_domain_func(out_fr_domain_t, target_domain)   
            loss = loss_class + loss_domain_source + loss_domain_target  
            
            loss.backward()  # Backpropagation  
            optimizer.step()  # Optimization  

            train_samples += source_class.size(0)
            class_loss += loss_class.item() * source_class.size(0)
            domain1_loss += loss_domain_source.item() * source_class.size(0)
            domain2_loss += loss_domain_target.item() * source_class.size(0)
            train1_acc += (out_fr_class_s.argmax(1) == source_class).float().sum().item()
            train2_acc += (out_fr_class_t.argmax(1) == target_class).float().sum().item()
            domain1_acc += (out_fr_domain_s.argmax(1) == source_domain).float().sum().item()
            domain2_acc += (out_fr_domain_t.argmax(1) == target_domain).float().sum().item()

            functional.reset_net(net)

        class_loss /= train_samples
        domain1_loss /= train_samples
        domain2_loss /= train_samples
        train1_acc /= train_samples
        train2_acc /= train_samples
        domain1_acc /= train_samples
        domain2_acc /= train_samples

        acc = evaluate_accuracy(test_data, net, encoder, args.T)

        writer.add_scalar('Class Acc on source', train1_acc, epoch)
        writer.add_scalar('Class Acc on target', train2_acc, epoch)
        writer.add_scalar('Domain Acc on source', domain1_acc, epoch)
        writer.add_scalar('Domain Acc on target', domain2_acc, epoch)
        writer.add_scalar('Class Loss', class_loss, epoch)
        writer.add_scalar('Domain Loss on source', domain1_loss, epoch)
        writer.add_scalar('Domain Loss on target', domain2_loss, epoch)
        writer.add_scalar('Class Acc on test', acc, epoch)
        
        print(f'Epoch {epoch + 1}, Class Loss: {class_loss:.4f}, Domain Source Loss: {domain1_loss:.4f}, Domain Target Loss: {domain2_loss:.4f}')   
        print(f'Class Acc on train_source: {train1_acc:.4f}, Domain Acc on train_source: {domain1_acc:.4f}')  
        print(f'Class Acc on train_target: {train2_acc:.4f}, Domain Acc on train_target: {domain2_acc:.4f}')  
        print(f'Class Acc on test: {acc:.4f}')  


In [17]:
net = SpikingDANN()  # Move to GPU if available  
  
train1_loader = torch.utils.data.DataLoader(train1_data, batch_size=128, shuffle=True)
train2_loader = torch.utils.data.DataLoader(train2_data, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=128, shuffle=True)

# Start training  
train_snn(net, train1_loader, train2_loader, test_loader)  

Epoch 1, Class Loss: 0.6418, Domain Source Loss: 0.7087, Domain Target Loss: 0.5851
Class Acc on train_source: 0.6037, Domain Acc on train_source: 0.4436
Class Acc on train_target: 0.6061, Domain Acc on train_target: 0.7821
Class Acc on test: 0.0997
Epoch 2, Class Loss: 0.5195, Domain Source Loss: 0.7084, Domain Target Loss: 0.6946
Class Acc on train_source: 0.7920, Domain Acc on train_source: 0.5219
Class Acc on train_target: 0.7103, Domain Acc on train_target: 0.5343
Class Acc on test: 0.1002
Epoch 3, Class Loss: 0.5163, Domain Source Loss: 0.6889, Domain Target Loss: 0.7016
Class Acc on train_source: 0.7944, Domain Acc on train_source: 0.5403
Class Acc on train_target: 0.7358, Domain Acc on train_target: 0.4933
Class Acc on test: 0.0997
Epoch 4, Class Loss: 0.5599, Domain Source Loss: 0.6929, Domain Target Loss: 0.6835
Class Acc on train_source: 0.7589, Domain Acc on train_source: 0.5010
Class Acc on train_target: 0.7749, Domain Acc on train_target: 0.5900
Class Acc on test: 0.0997
