# import libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
from torch import nn, optim
import os
import sys
from dataset import create_dataloaders  # 从dataset.py导入create_dataloaders函数
%matplotlib inline

# Report the runtime environment: PyTorch build and GPU availability.
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA是否可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"当前GPU: {torch.cuda.get_device_name(0)}")

# Matplotlib text setup: SimHei is selected so Chinese labels render, and
# unicode_minus is disabled so the minus sign still displays with that font.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})


# parameters

In [None]:
# NOTE(review): this path tweak runs after `from dataset import create_dataloaders`
# in the import cell, so it cannot help that import resolve. If dataset.py lives
# in the parent directory, this line has to move above the import — confirm.
sys.path.append(os.path.abspath(".."))
# Dataset location
data_root = os.path.abspath(os.path.join("..", "Aerial_Landscapes"))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Augmentation strategies to train with (one full training run per entry)
strategies = ['default', 'minimal', 'extensive', 'new']
num_classes = 15
# Training hyper-parameters
batch_size = 64
num_epochs = 5
learning_rate = 0.0001

# Per-class loss weights. Sized from num_classes (not a second hard-coded 15)
# so the vector cannot silently disagree with the model's output width.
class_weights = torch.ones(num_classes)

# Up-weight the classes that are easily confused with each other.
class_weights[[0, 3, 6, 13, 14]] = 1.2
class_weights[[8, 11]] = 1.5
class_weights = class_weights.to(device)

In [None]:
def create_model(num_classes=15, mode='train',print_structure=False):
    """Build a ``vit_base_patch16_224`` classifier with a custom two-layer head.

    Args:
        num_classes: Width of the final linear layer (number of logits).
        mode: ``'train'`` creates the backbone with pretrained weights,
            ``drop_rate=0.15``, ``attn_drop_rate=0.1`` and a head dropout of
            0.1. Any other value creates the backbone without pretrained
            weights, with timm's default drop rates, and a head dropout of 0.3.
        print_structure: When true, print the assembled module.

    Returns:
        The timm model with its ``head`` attribute replaced by
        Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.
    """
    if mode == 'train':
        backbone_kwargs = {'pretrained': True, 'drop_rate': 0.15, 'attn_drop_rate': 0.1}
        head_dropout = 0.1
    else:
        backbone_kwargs = {'pretrained': False}
        head_dropout = 0.3

    vit = timm.create_model('vit_base_patch16_224', **backbone_kwargs)

    # The replacement head is identical in both modes apart from its dropout rate.
    hidden_dim = 512
    vit.head = nn.Sequential(
        nn.Linear(vit.head.in_features, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(),
        nn.Dropout(head_dropout),
        nn.Linear(hidden_dim, num_classes),
    )

    if print_structure:
        print(vit)
    return vit
# Visualize the confusion matrix
def plot_confusion_matrix(conf_matrix, classes):
    """Draw ``conf_matrix`` as an annotated heatmap, save it, then show it.

    Args:
        conf_matrix: 2-D matrix of counts; cells are annotated with the
            integer (``'d'``) format.
        classes: Tick labels used on both axes.

    The figure is written to ``confusion_matrix.png`` in the working directory.
    """
    print("绘制混淆矩阵...")
    plt.figure(figsize=(15, 12))

    # sns.heatmap draws on the current axes and returns them, so the labels
    # can be set on the returned object.
    heatmap_options = {
        'annot': True,
        'fmt': 'd',
        'cmap': 'Blues',
        'xticklabels': classes,
        'yticklabels': classes,
    }
    axes = sns.heatmap(conf_matrix, **heatmap_options)
    axes.set_xlabel('Predicted Labels')
    axes.set_ylabel('True Labels')
    axes.set_title('Confusion Matrix')

    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.show()
    print(f"混淆矩阵已保存为 'confusion_matrix.png'")



In [None]:
# # 使用dataset.py创建数据加载器
# print("\n正在加载数据...")
# try:
#     # 使用导入的create_dataloaders函数，指定增强策略为"default"
#     train_loader, val_loader, test_loader, classes = create_dataloaders(
#         root_dir=data_root,
#         batch_size=batch_size,
#         split_ratio=[0.6, 0.2, 0.2],
#         augmentation_strategy='minimal',
#         random_seed=42,
#         num_workers=0,
#         verbose=True
#     )
#     print(f"数据加载完成。类别数: {len(classes)}, 类别: {classes}")
# except Exception as e:
#     print(f"加载数据时出错: {e}")


In [None]:
# Train one model per augmentation strategy and keep the best checkpoint overall.
# The loss uses the per-class weights (`class_weights`) from the parameters cell.

global_best_acc = 0.0
global_best_state = None
global_best_strategy = None

for st in strategies:
    print(f"\n=== Training with {st} ===")
    # Request the same split ratio and seed as the evaluation cell so both
    # cells ask for the same partition (presumably create_dataloaders derives
    # the split from these two arguments — verify in dataset.py).
    train_loader, val_loader, _, class_names = create_dataloaders(
        root_dir=data_root,
        batch_size=batch_size,
        split_ratio=[0.6, 0.2, 0.2],
        random_seed=42,
        augmentation_strategy=st,
        verbose=False,
    )

    model = create_model(num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    # `verbose` is deprecated in recent PyTorch releases; learning-rate changes
    # are logged explicitly after scheduler.step() below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    best_val_acc = 0.0
    best_state = None
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        correct = total = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # criterion returns a batch mean; weight it by batch size so the
            # epoch figure is a per-sample mean even with a short last batch.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_loss /= max(total, 1)
        train_acc = correct / max(total, 1)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= max(total, 1)
        val_acc = correct / max(total, 1)

        # Record all four series; the plots read every one of them.
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val Loss: {val_loss:.4f}")

        # Drive the plateau scheduler with the held-out loss, not the training loss.
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameters, which
            # later epochs keep updating; snapshot a detached CPU copy so the
            # stored weights really are the ones that scored best.
            best_state = {name: tensor.detach().cpu().clone()
                          for name, tensor in model.state_dict().items()}

    if best_state is not None and best_val_acc > global_best_acc:
        global_best_acc = best_val_acc
        global_best_state = best_state
        global_best_strategy = st

    # Training curve for this strategy
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Metrics')
    plt.title(f'Training Curve ({st})')
    plt.legend()
    plt.show()

# Save the best checkpoint across all strategies
if global_best_state is None:
    # Nothing improved on the initial accuracy of 0.0 (or no strategy ran);
    # avoid writing a file that contains None.
    print("No checkpoint was produced; nothing saved.")
else:
    os.makedirs("saved_models_vit", exist_ok=True)
    torch.save(global_best_state, f"saved_models_vit/best_vit_{global_best_strategy}.pth")
    print(f"Global best: {global_best_strategy} @ {global_best_acc:.4f}")
#
# for st in strategies:
#     criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
#
#     optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
#
#     # 训练模型
#     model = create_model(num_classes)
#     model = model.to(device)
#     history = {
#         'train_loss': [],
#         'val_loss': [],
#         'train_acc': [],
#         'val_acc': [],
#         'epoch_time': [],
#         'batch_losses': []  # 保留这个记录，用于训练后的可视化
#     }
#
#     best_val_acc = 0.0
#     batch_losses = []
#     total_batch = 0
#
#     for epoch in range(num_epochs):
#         # 训练阶段
#         model.train()
#         train_loss = 0.0
#         correct = 0
#         total = 0
#
#         start_time = time.time()
#         for batch_idx, (inputs, labels) in enumerate(train_loader):
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#
#             train_loss += loss.item() * inputs.size(0)
#             _, predicted = torch.max(outputs, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#
#             # 记录当前batch的损失
#             batch_losses.append(loss.item())
#             total_batch += 1
#
#             # 显示简单的进度条
#             if (batch_idx + 1) % 10 == 0 or batch_idx == len(train_loader) - 1:
#                 batch_acc = (predicted == labels).sum().item() / labels.size(0)
#                 print(f"Batch进度: {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}, 准确率: {batch_acc:.4f}",
#                       end='\r')
#
#         train_loss = train_loss / len(train_loader.dataset)
#         train_acc = correct / total
#
#         # 保存每个epoch结束时的所有batch损失
#         history['batch_losses'].extend(batch_losses[-len(train_loader):])
#
#         # 验证阶段
#         model.eval()
#         val_loss = 0.0
#         correct = 0
#         total = 0
#
#         with torch.no_grad():
#             for inputs, labels in val_loader:
#                 inputs, labels = inputs.to(device), labels.to(device)
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
#
#                 val_loss += loss.item() * inputs.size(0)
#                 _, predicted = torch.max(outputs, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()
#
#         val_loss = val_loss / len(val_loader.dataset)
#         val_acc = correct / total
#
#         # 学习率调整
#         current_lr = optimizer.param_groups[0]['lr']
#         scheduler.step(val_loss)
#         new_lr = optimizer.param_groups[0]['lr']
#
#         # 计算epoch时间
#         epoch_time = time.time() - start_time
#
#         # 保存最佳模型
#         is_best = val_acc > best_val_acc
#         if is_best:
#             best_val_acc = val_acc
#             torch.save(model.state_dict(), 'best_vit_model.pth')
#             best_mark = "✓ [最佳]"
#         else:
#             best_mark = ""
#
#         # 记录历史
#         history['train_loss'].append(train_loss)
#         history['val_loss'].append(val_loss)
#         history['train_acc'].append(train_acc)
#         history['val_acc'].append(val_acc)
#         history['epoch_time'].append(epoch_time)
#
#         # 美化打印输出
#         print(f"\n{'-' * 80}")
#         print(f"Epoch {epoch + 1}/{num_epochs} 完成 - 耗时: {epoch_time:.2f}秒 {best_mark}")
#         print(f"学习率: {current_lr:.8f} {'→ ' + str(new_lr) if current_lr != new_lr else ''}")
#         print(f"训练集 - Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f} ({correct}/{total})")
#         print(f"验证集 - Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")
#         if is_best:
#             print(f"✓ 新的最佳模型已保存! (验证准确率: {val_acc:.4f})")
#         print(f"{'-' * 80}")
#
#     print(f"\n{'-' * 80}")
#     print(f"训练完成! 最佳验证准确率: {best_val_acc:.4f}")
#     print(f"{'-' * 80}")
#



# 可视化训练历史

In [None]:
print("绘制训练历史图表...")
plt.figure(figsize=(15, 10))

# `history` is the dict left behind by the training cell. Series are read with
# .get() because that cell does not record every key used here (it never
# creates 'batch_losses'), and a missing key must not abort the whole figure.

# Loss curves
plt.subplot(2, 2, 1)
plt.plot(history.get('train_loss', []), label='Training Loss')
plt.plot(history.get('val_loss', []), label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Curves')
plt.legend()

# Accuracy curves
plt.subplot(2, 2, 2)
plt.plot(history.get('train_acc', []), label='Training Accuracy')
plt.plot(history.get('val_acc', []), label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy Curves')
plt.legend()

# Per-batch loss curve (only when the training run recorded it)
batch_losses = history.get('batch_losses', [])
plt.subplot(2, 1, 2)
if batch_losses:
    plt.plot(range(len(batch_losses)), batch_losses)
else:
    plt.text(0.5, 0.5, 'No per-batch losses recorded',
             ha='center', va='center', transform=plt.gca().transAxes)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Loss per Batch')
plt.grid(True)

plt.tight_layout()
plt.savefig('training_history.png')
plt.show()

print(f"图表已保存为 'training_history.png'")


In [None]:
# Load the best checkpoint.
# The training cell writes saved_models_vit/best_vit_<strategy>.pth, so resolve
# that path instead of a fixed 'best_vit_model.pth' that the training cell never creates.
import glob

_best_strategy = globals().get('global_best_strategy')
if _best_strategy is not None:
    best_model_path = os.path.join("saved_models_vit", f"best_vit_{_best_strategy}.pth")
else:
    # Training was not run in this session: fall back to the most recently
    # written checkpoint in the output directory.
    _candidates = sorted(
        glob.glob(os.path.join("saved_models_vit", "best_vit_*.pth")),
        key=os.path.getmtime,
    )
    if not _candidates:
        raise FileNotFoundError(
            "No checkpoint matching saved_models_vit/best_vit_*.pth; run the training cell first."
        )
    best_model_path = _candidates[-1]

best_model = create_model(num_classes=num_classes, mode='test')
# map_location lets a checkpoint saved during a GPU run load on a CPU-only machine.
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
print("最佳模型加载完成")

_, _, test_loader, classes = create_dataloaders(
    root_dir=data_root,
    batch_size=batch_size,
    augmentation_strategy='minimal',
    split_ratio=[0.6, 0.2, 0.2],
    random_seed=42,
    num_workers=0,
    verbose=False
)
def test_model(model, test_loader, classes):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: Classifier returning one row of logits per input sample.
        test_loader: Iterable of ``(inputs, labels)`` batches. Labels are
            treated as integer indices into ``classes``.
        classes: Class names, ordered by label index.

    Returns:
        Tuple ``(conf_matrix, report, accuracy, y_true, y_pred)``:
        the confusion matrix over every index of ``classes``, the text
        classification report, the overall accuracy, and the true/predicted
        label lists in loader order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    y_true = []
    y_pred = []

    # Used for the progress line below
    total_batches = len(test_loader)

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device)
            predicted = model(inputs).argmax(dim=1)

            # tolist() yields plain Python ints, usable directly as list indices
            y_true.extend(labels.tolist())
            y_pred.extend(predicted.tolist())

            print(f"测试进度: {batch_idx + 1}/{total_batches} 批次", end='\r')

    # Pin the label set to every known class. Without it, a class that never
    # appears in y_true/y_pred changes the matrix shape and makes
    # classification_report reject target_names for a length mismatch.
    label_ids = list(range(len(classes)))

    # Confusion matrix
    conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

    # Per-class report; zero_division=0 reports 0 for empty classes instead of warning
    report = classification_report(
        y_true, y_pred, labels=label_ids, target_names=classes, zero_division=0
    )

    # Overall accuracy
    accuracy = accuracy_score(y_true, y_pred)

    print(f"\n测试完成! 总体准确率: {accuracy:.4f}")

    return conf_matrix, report, accuracy, y_true, y_pred

# Run the evaluation on the test split
print("\n正在测试模型...")
conf_matrix, report, accuracy, y_true, y_pred = test_model(best_model, test_loader, classes)
# Print the test results: accuracy, per-class report, raw confusion matrix
print(f"\n测试准确率: {accuracy:.4f}")
print("\n分类报告:")
print(report)
print("\n混淆矩阵:")
print(conf_matrix)



In [None]:
# Visualize the confusion matrix computed by the test cell
plot_confusion_matrix(conf_matrix, classes)

# 打印所有错误的预测


In [None]:
# Visualize misclassified images
def visualize_misclassifications(model, test_loader, classes, y_true, y_pred, max_images=100):
    """Plot test images whose predicted label differs from the true label.

    ``y_true`` and ``y_pred`` are matched to images purely by position, so
    ``test_loader`` must yield samples in the same order as when the
    predictions were produced (i.e. it must not shuffle). A warning is printed
    when the labels read from the loader disagree with ``y_true``.

    Args:
        model: Trained model. It is moved to the active device and put in eval
            mode; no inference is run here.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        classes: Class names, indexed by label.
        y_true: True labels in loader order.
        y_pred: Predicted labels in loader order.
        max_images: Maximum number of misclassified images to draw.

    Returns:
        None. The grid is saved to ``misclassified_images.png`` and shown.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Positions where prediction and ground truth disagree
    misclassified_indices = [i for i, (y_t, y_p) in enumerate(zip(y_true, y_pred)) if y_t != y_p]
    print(f"共有 {len(misclassified_indices)} 个错误预测")

    if not misclassified_indices:
        print("没有错误预测！模型表现完美")
        return

    # Cap the number of images drawn
    if len(misclassified_indices) > max_images:
        print(f"仅显示前 {max_images} 个错误预测")
        misclassified_indices = misclassified_indices[:max_images]

    # Keep only the images that will be drawn instead of caching the whole
    # test set, and stop reading once the last wanted sample has been seen.
    print("收集测试数据...")
    wanted = set(misclassified_indices)
    last_wanted = max(wanted, default=-1)
    picked_images = {}
    picked_labels = {}
    offset = 0
    for inputs, labels in test_loader:
        if offset > last_wanted:
            break
        batch_len = inputs.size(0)
        for local in range(batch_len):
            position = offset + local
            if position in wanted:
                picked_images[position] = inputs[local].cpu()
                picked_labels[position] = int(labels[local])
        offset += batch_len

    # Indices the loader never reached cannot be drawn
    valid_indices = [i for i in misclassified_indices if i in picked_images]
    if len(valid_indices) < len(misclassified_indices):
        print(f"警告：有 {len(misclassified_indices) - len(valid_indices)} 个索引超出范围")

    # A zero-row figure is invalid, so bail out when nothing can be drawn
    if not valid_indices:
        print("没有可显示的错误预测图像")
        return

    # Positional matching is only right if the loader order matches y_true
    misaligned = sum(1 for i in valid_indices if picked_labels[i] != y_true[i])
    if misaligned:
        print(f"警告：有 {misaligned} 个样本的加载标签与 y_true 不一致，test_loader 的顺序可能与预测时不同")

    # Grid layout
    n_cols = 5
    n_rows = (len(valid_indices) + n_cols - 1) // n_cols

    # Undo input normalization for display; assumes the loader normalizes with
    # these ImageNet statistics — TODO confirm against dataset.py.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    plt.figure(figsize=(20, 4 * n_rows))

    for plot_pos, idx in enumerate(valid_indices, start=1):
        true_label = y_true[idx]
        pred_label = y_pred[idx]

        # CHW tensor -> HWC array, then de-normalize and clamp to [0, 1]
        img = picked_images[idx].numpy().transpose((1, 2, 0))
        img = np.clip(std * img + mean, 0, 1)

        plt.subplot(n_rows, n_cols, plot_pos)
        plt.imshow(img)
        plt.title(f"真实: {classes[true_label]}\n预测: {classes[pred_label]}", color='red')
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('misclassified_images.png')
    plt.show()
    print(f"错误预测图像已保存为 'misclassified_images.png'")

# Call the function to display the misclassified images
print("\n可视化错误预测...")
visualize_misclassifications(best_model, test_loader, classes, y_true, y_pred)


# Grad-CAM可视化模型决策


In [None]:
from pytorch_grad_cam import GradCAM, XGradCAM, GradCAMPlusPlus, AblationCAM, EigenCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.reshape_transforms import vit_reshape_transform
def _collect_images_by_index(test_loader, wanted_indices):
    """Return ``{sample position: CPU image tensor}`` for the requested positions only."""
    wanted = {int(i) for i in wanted_indices}
    found = {}
    if not wanted:
        return found
    last_wanted = max(wanted)
    offset = 0
    for inputs, _ in test_loader:
        if offset > last_wanted:
            break  # every requested sample has already been seen
        batch_len = inputs.size(0)
        for local in range(batch_len):
            if offset + local in wanted:
                found[offset + local] = inputs[local].cpu()
        offset += batch_len
    return found


def _draw_gradcam_row(cam, images, indices, y_true, y_pred, classes, device, n_cols, row, correct):
    """Draw one figure row of (original image, Grad-CAM overlay) pairs for ``indices``."""
    # Undo input normalization for display; assumes the loader normalizes with
    # these ImageNet statistics — TODO confirm against dataset.py.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    color = 'green' if correct else 'red'
    prefix = "正确预测" if correct else "错误预测"

    for col, idx in enumerate(indices):
        if idx not in images:
            continue  # position beyond what the loader yielded
        img_tensor = images[idx].unsqueeze(0).to(device)
        true_label = int(y_true[idx])
        pred_label = int(y_pred[idx])

        # Explain the class the model actually chose; for correctly classified
        # samples this is the true class as well.
        targets = [ClassifierOutputTarget(pred_label)]
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0, :]

        # CHW tensor -> HWC float32 array in [0, 1] for show_cam_on_image
        img = img_tensor.cpu().numpy().squeeze(0).transpose((1, 2, 0))
        img = np.float32(np.clip(std * img + mean, 0, 1))
        visualization = show_cam_on_image(img, grayscale_cam, use_rgb=True)

        if correct:
            caption = f"类别: {classes[true_label]}"
        else:
            caption = f"真实: {classes[true_label]}\n预测: {classes[pred_label]}"

        # Each sample takes two adjacent cells: original, then overlay
        first_cell = row * n_cols * 2 + col * 2 + 1

        plt.subplot(2, n_cols * 2, first_cell)
        plt.imshow(img)
        plt.title(f"{prefix} - 原图\n{caption}", color=color, fontsize=9)
        plt.axis('off')

        plt.subplot(2, n_cols * 2, first_cell + 1)
        plt.imshow(visualization)
        plt.title(f"{prefix} - 热力图\n{caption}", color=color, fontsize=9)
        plt.axis('off')


def visualize_gradcam_samples(model, test_loader, classes, y_true, y_pred, num_images=5):
    """Show Grad-CAM overlays for misclassified and correctly classified samples.

    The first figure row holds up to ``num_images`` misclassified samples and
    the second up to ``num_images`` correct ones. A row is left blank when its
    group is empty. ``y_true``/``y_pred`` are matched to images by position, so
    ``test_loader`` must yield samples in the order the predictions were made.

    Args:
        model: Trained ViT exposing ``blocks`` (the last block's ``norm1`` is
            the Grad-CAM target layer).
        test_loader: Iterable of ``(inputs, labels)`` batches.
        classes: Class names, indexed by label.
        y_true: True labels in loader order.
        y_pred: Predicted labels in loader order.
        num_images: Number of samples to draw per group (wrong / correct).

    Returns:
        None. The figure is saved to ``gradcam_visualization.png`` and shown.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    # Split sample positions into wrong and right predictions
    misclassified_indices = [i for i, (y_t, y_p) in enumerate(zip(y_true, y_pred)) if y_t != y_p]
    correctly_classified_indices = [i for i, (y_t, y_p) in enumerate(zip(y_true, y_pred)) if y_t == y_p]

    print(f"找到 {len(misclassified_indices)} 个错误预测和 {len(correctly_classified_indices)} 个正确预测")

    # Randomly sub-sample each group down to num_images
    if len(misclassified_indices) > num_images:
        misclassified_indices = np.random.choice(misclassified_indices, num_images, replace=False)
    if len(correctly_classified_indices) > num_images:
        correctly_classified_indices = np.random.choice(correctly_classified_indices, num_images, replace=False)
    misclassified_indices = [int(i) for i in misclassified_indices]
    correctly_classified_indices = [int(i) for i in correctly_classified_indices]

    # Size the grid by the larger group so that one empty group (e.g. a model
    # with no errors) does not collapse the grid to zero columns.
    n_rows = 2
    n_cols = max(len(misclassified_indices), len(correctly_classified_indices))
    if n_cols == 0:
        print("没有可用于GradCAM可视化的样本")
        return

    # Fetch only the images that will be drawn
    print("收集测试数据...")
    images = _collect_images_by_index(
        test_loader, misclassified_indices + correctly_classified_indices
    )

    # For ViT, hook the first norm layer of the last transformer block
    target_layers = [model.blocks[-1].norm1]

    # The context manager removes GradCAM's hooks from the model on exit
    with GradCAM(
        model=model,
        target_layers=target_layers,
        reshape_transform=vit_reshape_transform
    ) as cam:
        # Two cells (original + heat map) per sample
        plt.figure(figsize=(n_cols * 6, n_rows * 4))

        print("生成错误预测的GradCAM...")
        _draw_gradcam_row(cam, images, misclassified_indices, y_true, y_pred,
                          classes, device, n_cols, row=0, correct=False)

        print("生成正确预测的GradCAM...")
        _draw_gradcam_row(cam, images, correctly_classified_indices, y_true, y_pred,
                          classes, device, n_cols, row=1, correct=True)

    plt.tight_layout()
    plt.savefig('gradcam_visualization.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"GradCAM可视化已保存为 'gradcam_visualization.png'")

# Call the function: Grad-CAM for 5 wrong and 5 correct test predictions
print("\n使用GradCAM可视化模型决策...")
visualize_gradcam_samples(best_model, test_loader, classes, y_true, y_pred, num_images=5)
