In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # ImageNet标准化
])
transform_train = transforms.Compose([
    transforms.Resize((40, 40)),  # 先放大
    transforms.RandomCrop(32),     # 再随机裁剪
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # ImageNet标准化
])
batch_size=4
trainset=torchvision.datasets.CIFAR10(root=r'./',train=True,download=True,transform=transform_train)
testset=torchvision.datasets.CIFAR10(root=r'./',train=False,download=True,transform=transform)
trainloader=torch.utils.data.DataLoader(dataset=trainset,batch_size=batch_size,shuffle=True,num_workers=2)
testloader=torch.utils.data.DataLoader(dataset=testset,batch_size=batch_size,shuffle=False,num_workers=2)
classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

In [None]:
import torchvision.models as models
from torch.nn import MultiheadAttention

class build_model(nn.Module):
    def __init__(self, num_classes=10):
        super(build_model, self).__init__()
        
        # 加载预训练的VGG16
        vgg = models.vgg16(weights='IMAGENET1K_V1')
        
        # 提取不同层级的VGG特征
        self.vgg_block1 = nn.Sequential(*list(vgg.features.children())[:5])   # 输出64x32x32
        self.vgg_block2 = nn.Sequential(*list(vgg.features.children())[5:10]) # 输出128x16x16
        self.vgg_block3 = nn.Sequential(*list(vgg.features.children())[10:17])# 输出256x8x8
        self.vgg_block4 = nn.Sequential(*list(vgg.features.children())[17:20]) # 输出512x8x8
    
        # GRU层
        self.gru1 = nn.GRU(input_size=3, hidden_size=64, batch_first=True)
        self.gru2 = nn.GRU(input_size=64, hidden_size=128, batch_first=True)
        self.gru3 = nn.GRU(input_size=128, hidden_size=256, batch_first=True)
        self.gru4 = nn.GRU(input_size=256, hidden_size=512, batch_first=True)
        self.gru5 = nn.GRU(input_size=512, hidden_size=256, batch_first=True)

         # 解码器GRU层（新增）
        self.decoder_gru1 = nn.GRU(input_size=512, hidden_size=256, batch_first=True)  # 从vgg4回到vgg3
        self.decoder_gru2 = nn.GRU(input_size=256, hidden_size=128, batch_first=True)  # 从vgg3回到vgg2
        self.decoder_gru3 = nn.GRU(input_size=128, hidden_size=64, batch_first=True)  # 从vgg2回到vgg1
        self.decoder_gru4 = nn.GRU(input_size=64, hidden_size=3, batch_first=True)   # 从vgg1回到输入
        
        # 注意力层（每层使用对应的VGG特征）
        self.attn1 = MultiheadAttention(embed_dim=64, kdim=64, vdim=64, num_heads=4)
        self.attn2 = MultiheadAttention(embed_dim=128, kdim=128, vdim=128, num_heads=8)
        self.attn3 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)
        self.attn4 = MultiheadAttention(embed_dim=512, kdim=512, vdim=512, num_heads=8)
        self.attn5 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)

        # 解码器注意力层（新增）
        self.decoder_attn1 = MultiheadAttention(embed_dim=256, kdim=256, vdim=256, num_heads=8)  # vgg3特征
        self.decoder_attn2 = MultiheadAttention(embed_dim=128, kdim=128, vdim=128, num_heads=8)  # vgg2特征
        self.decoder_attn3 = MultiheadAttention(embed_dim=64, kdim=64, vdim=64, num_heads=4)    # vgg1特征
        self.decoder_attn4 = MultiheadAttention(embed_dim=3, kdim=3, vdim=3, num_heads=3)       # 输入特征
       
        # 层归一化
        self.norm1 = nn.LayerNorm(64)
        self.norm2 = nn.LayerNorm(128)
        self.norm3 = nn.LayerNorm(256)
        self.norm4 = nn.LayerNorm(512)
        self.norm5 = nn.LayerNorm(256)

        # 解码器层归一化（新增）
        self.decoder_norm1 = nn.LayerNorm(256)
        self.decoder_norm2 = nn.LayerNorm(128)
        self.decoder_norm3 = nn.LayerNorm(64)
        self.decoder_norm4 = nn.LayerNorm(3)

        # 特征融合层（新增）
        self.fusion1 = nn.Linear(256 + 256, 256)  # 融合编码器和解码器特征
        self.fusion2 = nn.Linear(128 + 128, 128)
        self.fusion3 = nn.Linear(64 + 64, 64)

        self.middrop=nn.Dropout(0.25)
        # 分类器
        self.fc1 = nn.Linear(515, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3=nn.Linear(32,num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def _to_sequence(self, features):
        """将特征图转换为序列"""
        batch_size, channels, height, width = features.size()
        seq = features.view(batch_size, channels, -1).permute(0, 2, 1)#[batch,h*w,channels]
        return seq


    def _adjust_sequence_length(self, seq, target_length):
        """调整序列长度"""
        # seq: [batch, seq_len, features]
        if seq.size(1) == target_length:
            return seq
        
        seq = seq.permute(0, 2, 1)  # [batch, features, seq_len]
        if seq.size(2) < target_length:
            # 上采样
            seq = nn.functional.interpolate(seq, size=target_length, mode='linear', align_corners=False)
        else:
            # 下采样
            seq = nn.AdaptiveAvgPool1d(target_length)(seq)
        seq = seq.permute(0, 2, 1)  # [batch, target_length, features]
        return seq

    def forward(self, x):
        # 提取多尺度VGG特征
        vgg1 = self.vgg_block1(x)  # [batch, 64, 32, 32]
        vgg2 = self.vgg_block2(vgg1)  # [batch, 128, 16, 16]
        vgg3 = self.vgg_block3(vgg2)  # [batch, 256, 8, 8]
        vgg4 = self.vgg_block4(vgg3) # [batch, 512, 8, 8]
        
        # 转换为序列
        vgg1_seq = self._to_sequence(vgg1)  # [batch, 1024, 64]
        vgg2_seq = self._to_sequence(vgg2)  # [batch, 256, 128]
        vgg3_seq = self._to_sequence(vgg3)  # [batch, 64, 256]
        vgg4_seq = self._to_sequence(vgg4)  # [batch, 64, 512]

        # 第一层处理
        x_seq = self._to_sequence(x)  
        gru1_out, _ = self.gru1(x_seq)  # [batch, 1024, 64]
        
        # 第一层注意力
        query1 = vgg1_seq.permute(1, 0, 2)  # [1024, batch, 64]
        key1 = gru1_out.permute(1, 0, 2)     # [1024, batch, 64]
        value1 = gru1_out.permute(1, 0, 2)   # [1024, batch, 64]
        attn_out1, _ = self.attn1(query1, key1, value1)
        attn_out1 = attn_out1.permute(1, 0, 2)  # [batch, 1024, 64]
        
        # 残差连接 + 归一化
        res1 = vgg1_seq + attn_out1
        norm1 = self.norm1(res1)
        
        # 第二层处理
        # 调整序列长度
        norm1 = norm1.permute(0, 2, 1)  # [batch, 64, 1024] [B,C,L]
        norm1 = nn.AdaptiveAvgPool1d(256)(norm1)  # 调整序列长度 [batch,64,256]
        norm1 = norm1.permute(0, 2, 1)  # [batch, 256, 64] [B,L,C]
        
        gru2_out, _ = self.gru2(norm1)  # [batch, 256, 128]
        
        # 第二层注意力
        query2 = vgg2_seq.permute(1,0,2)  # [256, batch, 128]
        key2 =  gru2_out.permute(1, 0, 2)    # [256, batch, 128]
        value2 =  gru2_out.permute(1, 0, 2)   # [256, batch, 128] 
        attn_out2, _ = self.attn2(query2, key2, value2)#[256,batch,128]  [长度，batch,维度]
        attn_out2 = attn_out2.permute(1, 0, 2)  # [batch, 256, 128]->[batch,长度,维度]
        
        # 残差连接 + 归一化
        res2 = vgg2_seq + attn_out2
        norm2 = self.norm2(res2)
        
        # 第三层处理
        # 调整序列长度
        norm2 = norm2.permute(0, 2, 1)  # [batch, 128, 256]
        norm2 = nn.AdaptiveAvgPool1d(64)(norm2)  # 调整序列长度
        norm2 = norm2.permute(0, 2, 1)  # [batch,64,128]
        
        gru3_out, _ = self.gru3(norm2)  # [batch, 64, 256]
        
        # 第三层注意力
        query3 = vgg3_seq.permute(1, 0, 2)  # [64, batch, 256]
        key3 = gru3_out.permute(1, 0, 2)    # [64, batch, 256]
        value3 = gru3_out.permute(1, 0, 2)   # [64, batch, 256]
        attn_out3, _ = self.attn3(query3, key3, value3)#[64,batch,256]
        attn_out3 = attn_out3.permute(1, 0, 2)  # [batch, 64, 256]
        
        # 残差连接 + 归一化
        res3 = vgg3_seq + attn_out3
        norm3 = self.norm3(res3)

        #第四层
        gru4_out,_=self.gru4(norm3) #[batch,64,512]
        query4=vgg4_seq.permute(1,0,2)#[64,batch,512]
        key4=gru4_out.permute(1,0,2)
        value4=gru4_out.permute(1,0,2)
        attn_out4,_=self.attn4(query4,key4,value4)#[64,batch,512]
        attn_out4=attn_out4.permute(1, 0, 2)#[batch,64,512]

        res4=gru4_out+attn_out4
        norm4=self.norm4(res4)
        encoder_final=norm4


        #解码器层
        decoder1_out, _ = self.decoder_gru1(norm4)  # [batch, 64, 256]
        # 解码器注意力：使用vgg3特征
        decoder_query1 = vgg3_seq.permute(1, 0, 2)  # [64, batch, 256]
        decoder_key1 = decoder1_out.permute(1, 0, 2)       # [64, batch, 256]
        decoder_value1 = decoder1_out.permute(1, 0, 2)      # [64, batch, 256]
        decoder_attn_out1, _ = self.decoder_attn1(decoder_query1, decoder_key1, decoder_value1)
        decoder_attn_out1 = decoder_attn_out1.permute(1, 0, 2)  # [batch, 64, 256]
        
        # 残差连接 + 归一化
        decoder_res1 = vgg3_seq + decoder_attn_out1 #[batch,64,256]
        decoder_norm1 = self.decoder_norm1(decoder_res1)
        
         # 特征融合：编码器和解码器特征
        fused1 = torch.cat([norm3, decoder_norm1], dim=-1)  # [batch, 64, 512]
        fused1 = self.fusion1(fused1)  # [batch, 64, 256]

        fused1_adj = self._adjust_sequence_length(fused1, 256)  # [batch, 256, 256]

        #第二层，使用vgg2特征
        decoder2_out, _ = self.decoder_gru2(decoder_norm1)#[batch,256,128]
        decoder_query2 = vgg2_seq.permute(1, 0, 2)   #[256,batch,128]
        decoder_key2 = decoder2_out.permute(1, 0, 2)       #[256,batch,128]
        decoder_value2 = decoder2_out.permute(1, 0, 2)     #[256,batch,128]
        decoder_attn_out2, _ = self.decoder_attn2(decoder_query2, decoder_key2, decoder_value2)
        decoder_attn_out2 = decoder_attn_out2.permute(1, 0, 2) #[batch,256,128]
        # 残差连接 + 归一化
        decoder_res2 = vgg2_seq + decoder_attn_out2
        decoder_norm2 = self.decoder_norm2(decoder_res2)
        # 特征融合
        fused2 = torch.cat([norm2, decoder_norm2], dim=-1)  # [batch, 256, 256]
        fused2 = self.fusion2(fused2)  # [batch, 256, 128]
        
        # 第三层解码：从vgg2回到vgg1
        fused2_adj = self._adjust_sequence_length(fused2, 1024)  # [batch, 1024, 128]
        decoder3_out, _ = self.decoder_gru3(fused2_adj)  # [batch, 1024, 64]
        
        # 解码器注意力：使用vgg1特征
        decoder_query3 = vgg1_seq.permute(1, 0, 2)  # [1024, batch, 64]
        decoder_key3 = decoder3_out.permute(1, 0, 2)        # [1024, batch, 64]
        decoder_value3 = decoder3_out.permute(1, 0, 2)      # [1024, batch, 64]
        decoder_attn_out3, _ = self.decoder_attn3(decoder_query3, decoder_key3, decoder_value3)
        decoder_attn_out3 = decoder_attn_out3.permute(1, 0, 2)  # [batch, 1024, 64]
        
        # 残差连接 + 归一化
        decoder_res3 = vgg1_seq + decoder_attn_out3
        decoder_norm3 = self.decoder_norm3(decoder_res3)
        # 特征融合
        fused3 = torch.cat([norm1, decoder_norm3], dim=-1)  # [batch, 1024, 128]
        fused3 = self.fusion3(fused3)  # [batch, 1024, 64]
        
         # 第四层解码：从vgg1回到输入
        decoder4_out, _ = self.decoder_gru4(fused3)  # [batch, 1024, 3]
        
        # 解码器注意力：使用输入特征
        decoder_query4 = x_seq.permute(1, 0, 2)  # [1024, batch, 3]
        decoder_key4 = decoder4_out.permute(1, 0, 2)           # [1024, batch, 3]
        decoder_value4 = decoder4_out.permute(1, 0, 2)        # [1024, batch, 3]
        decoder_attn_out4, _ = self.decoder_attn4(decoder_query4, decoder_key4, decoder_value4)
        decoder_attn_out4 = decoder_attn_out4.permute(1, 0, 2)  # [batch, 1024, 3]
        
        # 残差连接 + 归一化
        decoder_res4 = x_seq + decoder_attn_out4
        decoder_final = self.decoder_norm4(decoder_res4)  # [batch, 1024, 3]

         # === 最终特征融合和分类 ===
        # 使用编码器最终特征和解码器最终特征
        encoder_pooled = torch.mean(encoder_final, dim=1)  # [batch, 512]
        decoder_pooled = torch.mean(decoder_final, dim=1)  # [batch, 3]
        
        # 融合两种特征
        final_features = torch.cat([encoder_pooled, decoder_pooled], dim=1)  # [batch, 515]

        # 分类器
        x=self.relu(self.fc1(final_features))
        x = self.fc2(x)
        x=self.fc3(x)
        return x

In [9]:
import matplotlib.pyplot as plt
model=build_model().to(device)
LR=0.0001
criterion=nn.CrossEntropyLoss().cuda()
optimizer=optim.Adam(model.parameters(),LR,weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, 
    T_0=10,     # 第一次重启的周期
    T_mult=2,   # 每次重启后周期倍增
    eta_min=1e-6  # 最小学习率
)
max_grad_norm = 1.0  # 梯度裁剪阈值
# 添加图像描述变量
image_captions = []

# 创建存储训练指标的列表
train_losses = []
val_losses = []  # 新增验证损失列表
val_accuracies = []
learning_rates = []
epochs=30
for epoch in range(epochs):
    running_loss=0.0
    total=0
    correct=0
    model.train()
    # 添加当前epoch的图像描述
    epoch_caption = f"Epoch {epoch+1}/{epochs}: Training in progress..."
    image_captions.append(epoch_caption)
    for i,data in enumerate(trainloader,0):
        inputs,labels=data
        inputs=inputs.cuda()
        labels=labels.cuda()
        optimizer.zero_grad()
        outputs=model(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        running_loss+=loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if i % 2000 == 1999:  # 每2000个batch打印一次
            train_acc = 100 * correct / total
            batch_caption = (f'Epoch [{epoch+1}/{epochs}], Batch [{i+1}/{len(trainloader)}], '
                            f'Loss: {running_loss/2000:.4f}, Acc: {train_acc:.2f}%')
            print(batch_caption)
            image_captions.append(batch_caption)
            running_loss = 0.0
            total = 0
            correct = 0
            
        # 更新学习率
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)

     # 验证过程
    model.eval()
    val_correct = 0
    val_total = 0
    val_loss = 0.0
    
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()
    
    val_acc = 100 * val_correct / val_total
    avg_val_loss = val_loss / len(testloader)
    val_losses.append(avg_val_loss)  # 记录验证损失
    val_accuracies.append(val_acc)
    
    epoch_summary = (f'Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}, '
                    f'Accuracy: {val_acc:.2f}%, LR: {current_lr:.6f}')
    print(epoch_summary)
    image_captions.append(epoch_summary)
    train_losses.append(val_loss)


print('Finished Training')

Epoch [1/30], Batch [2000/12500], Loss: 1.5703, Acc: 43.40%
Epoch [1/30], Batch [4000/12500], Loss: 1.1840, Acc: 58.25%
Epoch [1/30], Batch [6000/12500], Loss: 1.0559, Acc: 63.24%
Epoch [1/30], Batch [8000/12500], Loss: 1.0226, Acc: 65.10%
Epoch [1/30], Batch [10000/12500], Loss: 0.8487, Acc: 71.09%
Epoch [1/30], Batch [12000/12500], Loss: 0.9959, Acc: 65.64%
Epoch 1, Validation Loss: 0.9343, Accuracy: 68.87%, LR: 0.000088
Epoch [2/30], Batch [2000/12500], Loss: 0.9316, Acc: 68.26%
Epoch [2/30], Batch [4000/12500], Loss: 0.8380, Acc: 72.29%
Epoch [2/30], Batch [6000/12500], Loss: 0.7598, Acc: 74.40%
Epoch [2/30], Batch [8000/12500], Loss: 0.7036, Acc: 76.55%
Epoch [2/30], Batch [10000/12500], Loss: 0.9390, Acc: 68.92%
Epoch [2/30], Batch [12000/12500], Loss: 0.9176, Acc: 69.08%
Epoch 2, Validation Loss: 0.8645, Accuracy: 72.88%, LR: 0.000089
Epoch [3/30], Batch [2000/12500], Loss: 0.8520, Acc: 72.05%
Epoch [3/30], Batch [4000/12500], Loss: 0.8196, Acc: 73.06%
Epoch [3/30], Batch [6000/

In [10]:
# 保存最终模型
torch.save(model.state_dict(), 'multiheadgru_model.pth')
final_caption = "Training completed. Model saved as final_model.pth"
print(final_caption)
image_captions.append(final_caption)

# 绘制测试集准确率变化折线图
plt.figure(figsize=(10, 6))
plt.plot(range(1, epochs+1), val_accuracies, 'b-o', linewidth=2)
plt.title('Test Accuracy Over Epochs', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.xticks(range(1, epochs+1))
plt.ylim(0, 100)  # 确保y轴从0到100%

# 标记最高准确率
max_acc = max(val_accuracies)
max_epoch = val_accuracies.index(max_acc) + 1
plt.annotate(f'Max: {max_acc:.2f}%', 
             xy=(max_epoch, max_acc),
             xytext=(max_epoch+1, max_acc-5),
             arrowprops=dict(facecolor='red', shrink=0.05),
             fontsize=12)

# 保存图表
plt.savefig('accuracy_plot_multiheadgru_3.png', dpi=300, bbox_inches='tight')
plt.close()

# 添加图表描述
plot_caption = ("Accuracy Plot: Shows the model's performance improvement on the test set over training epochs. "
               f"Highest accuracy of {max_acc:.2f}% achieved at epoch {max_epoch}.")
image_captions.append(plot_caption)

# 保存所有图像描述到文件
with open('training_report.txt', 'w') as f:
    for caption in image_captions:
        f.write(caption + '\n')

print("Training report and accuracy plot saved successfully.")

Training completed. Model saved as final_model.pth
Training report and accuracy plot saved successfully.
