In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dset
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import copy
# Of the 50,000 CIFAR-10 training images, the first 49,000 are used for
# training and the remaining 1,000 are held out as a validation set.
NUM_TRAIN = 49000
SEED = 0
# Seed torch's RNG so weight init, random augmentation and sampler order are reproducible
torch.manual_seed(SEED)

# CIFAR-10 per-channel mean/std. Every transform must use the same statistics;
# otherwise augmented training batches and val/test batches end up on different scales.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Preprocessing: convert to tensor, then normalize with the CIFAR-10 statistics
transform_normal = T.Compose([
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])
# Data augmentation (training only): random 32x32 crop of a 4-pixel-padded image
# plus a random horizontal flip, normalized exactly like val/test data
transform_aug = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(CIFAR10_MEAN, CIFAR10_STD)
])

# Training set: indices [0, NUM_TRAIN) of the CIFAR-10 training split
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_normal)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))

# Validation set: indices [NUM_TRAIN, 50000) of the same training split
cifar10_val = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_normal)
loader_val = DataLoader(cifar10_val, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN, 50000)))

# Test set: the separate CIFAR-10 test split
cifar10_test = dset.CIFAR10('./dataset', train=False, download=True, transform=transform_normal)
loader_test = DataLoader(cifar10_test, batch_size=64)

USE_GPU = True
dtype = torch.float32
print_every = 100

if USE_GPU and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('using device:', device)

using device: cuda


In [2]:
# Evaluate a model's accuracy on the validation or test set
def check_accuracy(loader, model):
    """Compute and print the top-1 accuracy of ``model`` over every batch of ``loader``.

    Parameters:
    - loader: DataLoader yielding (images, labels) batches. The dataset's ``train``
      attribute (set by torchvision's CIFAR10) only selects the printed label:
      truthy -> "validation set", falsy or missing -> "test set".
    - model: nn.Module returning class scores of shape [N, num_classes].

    Returns: accuracy as a float in [0, 1]; 0.0 when the loader yields no samples.

    Side effects: switches ``model`` to eval mode and leaves it there. Inputs are
    moved to the module-level ``device`` / ``dtype``.
    """
    if getattr(loader.dataset, 'train', False):
        print('Checking accuracy on validation set')
    else:
        print('Checking accuracy on test set')
    num_correct = 0
    num_samples = 0
    model.eval()   # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            _, preds = scores.max(1)
            # .item() keeps the counters as plain Python ints rather than device tensors
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    # An empty loader would otherwise raise ZeroDivisionError
    acc = float(num_correct) / num_samples if num_samples else 0.0
    print('Got %d / %d correct (%.2f)' % (num_correct, num_samples, 100 * acc))
    return acc
def train_model(model, optimizer, accs, model_name, epochs=1, scheduler=None,
                train_loader=None, val_loader=None):
    '''
    Train ``model`` and keep the weights that scored best on the validation set.

    Parameters:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An optimizer object we will use to train the model.
    - accs: list that receives each epoch's validation accuracy in percent
      (appended in place).
    - model_name: path the best weights are written to with ``torch.save``.
    - epochs: A Python integer giving the number of epochs to train.
    - scheduler: optional LR scheduler, stepped once at the end of every epoch,
      i.e. after that epoch's optimizer updates.
    - train_loader / val_loader: loaders to use; when None they fall back to
      the module-level ``loader_train`` / ``loader_val`` (looked up at call time).

    Returns: the model (moved to ``device``) with the best validation weights
    loaded. If ``epochs`` is 0 the model is returned unchanged.
    '''
    if train_loader is None:
        train_loader = loader_train
    if val_loader is None:
        val_loader = loader_val

    best_model_wts = None
    best_acc = 0.0
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    for e in range(epochs):
        model.train()  # check_accuracy() leaves the model in eval mode
        loss = None
        for x, y in train_loader:
            x = x.to(device, dtype=dtype)
            y = y.to(device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Read the loss once per epoch to avoid a device sync on every step;
        # NaN signals that the training loader produced no batches.
        last_loss = loss.item() if loss is not None else float('nan')
        print('Epoch %d, loss=%.4f' % (e, last_loss))
        acc = check_accuracy(val_loader, model)
        accs.append(100 * acc)
        # Always record the first epoch so there is something to restore.
        if best_model_wts is None or acc > best_acc:
            best_model_wts = copy.deepcopy(model.state_dict())
            best_acc = acc

        # Since PyTorch 1.1 the scheduler must step after optimizer.step();
        # stepping at the start of the epoch skips the initial learning rate.
        if scheduler:
            scheduler.step()

    print('best_acc:', best_acc)
    if best_model_wts is not None:  # None only when no epoch was run
        model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), model_name)
    return model

In [4]:
import torchvision.models as models
from torch.nn import MultiheadAttention

class build_model(nn.Module):
    """VGG16 features + GRU / multi-head-attention encoder-decoder classifier.

    Pretrained VGG16 ``features`` are cut into four stages whose feature maps act
    as attention queries; GRUs run over flattened (h*w) pixel sequences and supply
    keys/values. A mirrored decoder walks back from the deepest stage to the raw
    input, and the pooled encoder/decoder outputs feed an MLP classifier.

    Shape comments below assume a [B, 3, 32, 32] (CIFAR-10) input.
    Unused modules kept for checkpoint compatibility: gru5, attn5, norm5,
    middrop, dropout.
    """
    def __init__(self, num_classes=10):
        super(build_model, self).__init__()

        # Pretrained VGG16 (ImageNet weights)
        vgg = models.vgg16(weights='IMAGENET1K_V1')
        vgg_layers = list(vgg.features.children())

        # Slices of vgg16.features. Blocks 1-3 each end with a 2x2 max-pool, so the
        # spatial size halves per block; block 4 (conv-ReLU-conv) keeps it.
        self.vgg_block1 = nn.Sequential(*vgg_layers[:5])    # -> [B, 64, 16, 16]
        self.vgg_block2 = nn.Sequential(*vgg_layers[5:10])  # -> [B, 128, 8, 8]
        self.vgg_block3 = nn.Sequential(*vgg_layers[10:17]) # -> [B, 256, 4, 4]
        self.vgg_block4 = nn.Sequential(*vgg_layers[17:20]) # -> [B, 512, 4, 4]

        # Encoder GRUs (gru5 is not used in forward)
        self.gru1 = nn.GRU(input_size=3, hidden_size=64, batch_first=True)
        self.gru2 = nn.GRU(input_size=64, hidden_size=128, batch_first=True)
        self.gru3 = nn.GRU(input_size=128, hidden_size=256, batch_first=True)
        self.gru4 = nn.GRU(input_size=256, hidden_size=512, batch_first=True)
        self.gru5 = nn.GRU(input_size=512, hidden_size=256, batch_first=True)

        # Decoder GRUs
        self.decoder_gru1 = nn.GRU(input_size=512, hidden_size=256, batch_first=True)  # stage 4 -> stage 3
        self.decoder_gru2 = nn.GRU(input_size=256, hidden_size=128, batch_first=True)  # stage 3 -> stage 2
        self.decoder_gru3 = nn.GRU(input_size=128, hidden_size=64, batch_first=True)   # stage 2 -> stage 1
        self.decoder_gru4 = nn.GRU(input_size=64, hidden_size=3, batch_first=True)     # stage 1 -> input

        # Encoder attention; the query is the matching VGG stage (attn5 is unused)
        self.attn1 = MultiheadAttention(embed_dim=64, kdim=64, vdim=64, num_heads=4)
        self.attn2 = MultiheadAttention(embed_dim=128, kdim=128, vdim=128, num_heads=8)
        self.attn3 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)
        self.attn4 = MultiheadAttention(embed_dim=512, kdim=512, vdim=512, num_heads=8)
        self.attn5 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)

        # Decoder attention
        self.decoder_attn1 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)  # vgg3 query
        self.decoder_attn2 = MultiheadAttention(embed_dim=128, kdim=128, vdim=128, num_heads=8)  # vgg2 query
        self.decoder_attn3 = MultiheadAttention(embed_dim=64, kdim=64, vdim=64, num_heads=4)     # vgg1 query
        self.decoder_attn4 = MultiheadAttention(embed_dim=3, kdim=3, vdim=3, num_heads=3)        # raw-input query

        # Encoder layer norms (norm5 is unused)
        self.norm1 = nn.LayerNorm(64)
        self.norm2 = nn.LayerNorm(128)
        self.norm3 = nn.LayerNorm(256)
        self.norm4 = nn.LayerNorm(512)
        self.norm5 = nn.LayerNorm(256)

        # Decoder layer norms
        self.decoder_norm1 = nn.LayerNorm(256)
        self.decoder_norm2 = nn.LayerNorm(128)
        self.decoder_norm3 = nn.LayerNorm(64)
        self.decoder_norm4 = nn.LayerNorm(3)

        # Fuse encoder and decoder features of the same stage (concat -> linear)
        self.fusion1 = nn.Linear(256 + 256, 256)
        self.fusion2 = nn.Linear(128 + 128, 128)
        self.fusion3 = nn.Linear(64 + 64, 64)

        # NOTE(review): middrop and dropout are defined but never applied in forward.
        self.middrop = nn.Dropout(0.25)
        # Classifier: 512 pooled encoder channels + 3 pooled decoder channels = 515
        self.fc1 = nn.Linear(515, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def _to_sequence(self, features):
        """Flatten a [B, C, H, W] feature map to a [B, H*W, C] sequence."""
        batch_size, channels, height, width = features.size()
        seq = features.view(batch_size, channels, -1).permute(0, 2, 1)  # [B, H*W, C]
        return seq

    def _adjust_sequence_length(self, seq, target_length):
        """Resample a [B, L, C] sequence to [B, target_length, C].

        Returns ``seq`` itself when the length already matches; otherwise upsamples
        with linear interpolation or downsamples with adaptive average pooling.
        """
        if seq.size(1) == target_length:
            return seq

        seq = seq.permute(0, 2, 1)  # [B, C, L]
        if seq.size(2) < target_length:
            seq = F.interpolate(seq, size=target_length, mode='linear', align_corners=False)
        else:
            seq = F.adaptive_avg_pool1d(seq, target_length)
        return seq.permute(0, 2, 1)  # [B, target_length, C]

    def _attend(self, attn, query, key_value):
        """Cross-attention on batch-first [B, L, E] tensors.

        The attention modules here are built without batch_first, so they take
        [L, B, E]; this permutes in and out. Output length follows the query.
        """
        attn_out, _ = attn(query.permute(1, 0, 2),
                           key_value.permute(1, 0, 2),
                           key_value.permute(1, 0, 2))
        return attn_out.permute(1, 0, 2)

    def forward(self, x):
        # Multi-scale VGG features
        vgg1 = self.vgg_block1(x)     # [B, 64, 16, 16]
        vgg2 = self.vgg_block2(vgg1)  # [B, 128, 8, 8]
        vgg3 = self.vgg_block3(vgg2)  # [B, 256, 4, 4]
        vgg4 = self.vgg_block4(vgg3)  # [B, 512, 4, 4]

        vgg1_seq = self._to_sequence(vgg1)  # [B, 256, 64]
        vgg2_seq = self._to_sequence(vgg2)  # [B, 64, 128]
        vgg3_seq = self._to_sequence(vgg3)  # [B, 16, 256]
        vgg4_seq = self._to_sequence(vgg4)  # [B, 16, 512]
        x_seq = self._to_sequence(x)        # [B, 1024, 3]

        # ---- Encoder ----
        # Stage 1: GRU over raw pixels, queried by VGG stage-1 features; residual on the query
        gru1_out, _ = self.gru1(x_seq)                                                # [B, 1024, 64]
        norm1 = self.norm1(vgg1_seq + self._attend(self.attn1, vgg1_seq, gru1_out))   # [B, 256, 64]
        # Pool to a fixed length of 256 (identity for 32x32 inputs)
        norm1 = F.adaptive_avg_pool1d(norm1.permute(0, 2, 1), 256).permute(0, 2, 1)   # [B, 256, 64]

        # Stage 2
        gru2_out, _ = self.gru2(norm1)                                                # [B, 256, 128]
        norm2 = self.norm2(vgg2_seq + self._attend(self.attn2, vgg2_seq, gru2_out))   # [B, 64, 128]
        # Pool to a fixed length of 64 (identity for 32x32 inputs)
        norm2 = F.adaptive_avg_pool1d(norm2.permute(0, 2, 1), 64).permute(0, 2, 1)    # [B, 64, 128]

        # Stage 3
        gru3_out, _ = self.gru3(norm2)                                                # [B, 64, 256]
        norm3 = self.norm3(vgg3_seq + self._attend(self.attn3, vgg3_seq, gru3_out))   # [B, 16, 256]

        # Stage 4: unlike stages 1-3, the residual is the GRU output, not the VGG query
        gru4_out, _ = self.gru4(norm3)                                                # [B, 16, 512]
        norm4 = self.norm4(gru4_out + self._attend(self.attn4, vgg4_seq, gru4_out))   # [B, 16, 512]
        encoder_final = norm4

        # ---- Decoder ----
        # Stage 4 -> 3, queried by VGG stage-3 features, fused with encoder stage 3
        decoder1_out, _ = self.decoder_gru1(norm4)                                    # [B, 16, 256]
        decoder_norm1 = self.decoder_norm1(
            vgg3_seq + self._attend(self.decoder_attn1, vgg3_seq, decoder1_out))       # [B, 16, 256]
        fused1 = self.fusion1(torch.cat([norm3, decoder_norm1], dim=-1))               # [B, 16, 256]
        fused1_adj = self._adjust_sequence_length(fused1, 256)                         # [B, 256, 256]

        # Stage 3 -> 2. The fused features feed the next decoder GRU (as fused2_adj and
        # fused3 do below); feeding decoder_norm1 instead would leave fusion1 unused.
        decoder2_out, _ = self.decoder_gru2(fused1_adj)                                # [B, 256, 128]
        decoder_norm2 = self.decoder_norm2(
            vgg2_seq + self._attend(self.decoder_attn2, vgg2_seq, decoder2_out))       # [B, 64, 128]
        fused2 = self.fusion2(torch.cat([norm2, decoder_norm2], dim=-1))               # [B, 64, 128]
        fused2_adj = self._adjust_sequence_length(fused2, 1024)                        # [B, 1024, 128]

        # Stage 2 -> 1
        decoder3_out, _ = self.decoder_gru3(fused2_adj)                                # [B, 1024, 64]
        decoder_norm3 = self.decoder_norm3(
            vgg1_seq + self._attend(self.decoder_attn3, vgg1_seq, decoder3_out))       # [B, 256, 64]
        fused3 = self.fusion3(torch.cat([norm1, decoder_norm3], dim=-1))               # [B, 256, 64]

        # Stage 1 -> input, queried by the raw pixel sequence
        decoder4_out, _ = self.decoder_gru4(fused3)                                    # [B, 256, 3]
        decoder_final = self.decoder_norm4(
            x_seq + self._attend(self.decoder_attn4, x_seq, decoder4_out))             # [B, 1024, 3]

        # ---- Classification ----
        encoder_pooled = torch.mean(encoder_final, dim=1)  # [B, 512]
        decoder_pooled = torch.mean(decoder_final, dim=1)  # [B, 3]
        final_features = torch.cat([encoder_pooled, decoder_pooled], dim=1)  # [B, 515]

        out = self.relu(self.fc1(final_features))
        # Nonlinearity between fc2 and fc3; without it the two layers collapse into one linear map
        out = self.relu(self.fc2(out))
        return self.fc3(out)

In [5]:
# Train for more epochs with data augmentation and step learning-rate decay
NUM_EPOCHS = 50
# Rebuild the training loader with the augmenting transform (same 49,000-image subset)
cifar10_train = dset.CIFAR10('./dataset', train=True, download=True, transform=transform_aug)
loader_train = DataLoader(cifar10_train, batch_size=64, sampler=sampler.SubsetRandomSampler(range(NUM_TRAIN)))
learning_rate = 1e-2
accs = []  # per-epoch validation accuracy (%), appended by train_model
model = build_model()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
# StepLR: multiply the LR by 0.1 every 15 scheduler steps (train_model steps once per epoch)
scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
best_model = train_model(model, optimizer, accs, 'multiheadgru.pth', NUM_EPOCHS, scheduler)
# loader_test is the held-out CIFAR-10 test split, so this is test accuracy
test_acc = check_accuracy(loader_test, best_model)

ValueError: sampler option is mutually exclusive with shuffle

In [None]:

import matplotlib.pyplot as plt

# Validation accuracy of the MultiHeadGRU model, one point per trained epoch.
# x-values are derived from accs so they always match the number of epochs run.
epochs = np.arange(len(accs))
max_idx_gru = int(np.argmax(accs))  # epoch with the best validation accuracy

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, accs, '-o', color='red', label='MultiHeadGRU')

# Mark the maximum
ax.scatter(epochs[max_idx_gru], accs[max_idx_gru], color='red', s=100, marker='^')

# Offset the label slightly so it does not overlap the marker
ax.text(epochs[max_idx_gru] + 0.5, accs[max_idx_gru] + 0.5, f"Max: {accs[max_idx_gru]:.2f}%", color='red')
ax.set_title("Validation accuracy per epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.set_ylim(50, 100)
ax.legend()
ax.grid(True, linestyle='--', alpha=0.6)
fig.savefig('multiheadgru.png')
plt.show()