# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc, precision_recall_curve, average_precision_score, f1_score
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import joblib
import seaborn as sns

# Custom dataset class
class PlantDocDataset(Dataset):
    """Image-classification dataset read from a PlantDoc-style index file.

    Each non-empty line of ``txt_path`` is split on ``'='``; lines with fewer than
    three fields are skipped. Field 0 is the image path relative to
    ``<root_dir>/images`` (forward slashes are converted to ``os.sep``), field 1 is
    an integer label, and any further fields are ignored. Entries whose image file
    does not exist on disk are dropped.

    The surviving samples are shuffled deterministically with ``random_seed``; the
    first ``int(train_ratio * n)`` shuffled indices form the training split and the
    remainder the test split, so a ``train=True`` and a ``train=False`` instance built
    with the same arguments are complementary.

    Args:
        root_dir: dataset root containing an ``images`` directory.
        txt_path: path of the index file.
        transform: optional callable applied to each loaded PIL image.
        train: select the training split (True) or the test split (False).
        train_ratio: fraction of samples assigned to the training split.
        random_seed: seed for the split shuffle.
    """

    def __init__(self, root_dir, txt_path, transform=None, train=True, train_ratio=0.8, random_seed=42):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []  # list of (absolute image path, int label)

        with open(txt_path, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                parts = line.split('=')
                if len(parts) < 3:
                    continue
                img_rel_path, label_str = parts[0], parts[1]
                img_full_path = os.path.join(root_dir, 'images', img_rel_path.replace('/', os.path.sep))
                if not os.path.exists(img_full_path):
                    continue
                self.samples.append((img_full_path, int(label_str)))

        # Shuffle with a private RandomState. It produces the same permutation as
        # seeding the global np.random state with random_seed, so the splits are
        # unchanged, but it no longer resets the global RNG for the rest of the script.
        num_samples = len(self.samples)
        indices = list(range(num_samples))
        np.random.RandomState(random_seed).shuffle(indices)
        split_idx = int(train_ratio * num_samples)
        self.indices = indices[:split_idx] if train else indices[split_idx:]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the idx-th sample of this split."""
        img_path, label = self.samples[self.indices[idx]]
        # The context manager closes the file handle; convert() returns a new, loaded image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Preprocessing shared by the train and test splits: a fixed 224x224 resize (aspect
# ratio is not preserved), conversion to a tensor, then per-channel normalisation
# with the usual ImageNet mean/std values.
transform = transforms.Compose(transforms=[
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(std=[0.229, 0.224, 0.225], mean=[0.485, 0.456, 0.406]),
])

# Number of classes
def get_num_classes(txt_path):
    """Return how many classifier outputs are needed for the labels in ``txt_path``.

    Lines are parsed like ``PlantDocDataset`` does: split on ``'='``, skip lines with
    fewer than three fields, and read field 1 as an integer label.

    The result is ``max(label) + 1``, so every label is a valid class index for a
    ``num_classes``-way output layer and ``CrossEntropyLoss``. Counting distinct labels
    is too small when labels are non-contiguous or 1-based. For contiguous 0-based
    labels both give the same value.

    Returns:
        ``max(label) + 1``, or 0 if the file contains no valid lines.
    """
    max_label = -1
    with open(txt_path, 'r') as f:
        for line in f:
            parts = line.strip().split('=')
            if len(parts) < 3:  # also covers empty lines
                continue
            max_label = max(max_label, int(parts[1]))
    return max_label + 1

# Dataset location (hard-coded Windows paths; adjust for your machine).
root_dir = r'E:\data1\plantdoc'
txt_path = r'E:\data1\plantdoc\trainval.txt'

# Number of classifier outputs, derived from the labels in the index file.
num_classes = get_num_classes(txt_path)

# Train and test splits of the same index file. Both use the constructor's default
# train_ratio and random_seed, so the two splits are complementary.
train_dataset, test_dataset = (
    PlantDocDataset(root_dir=root_dir, txt_path=txt_path, transform=transform, train=is_train)
    for is_train in (True, False)
)

# Only the training loader shuffles; test order stays fixed for later alignment.
train_loader, test_loader = (
    DataLoader(split, batch_size=16, shuffle=should_shuffle)
    for split, should_shuffle in ((train_dataset, True), (test_dataset, False))
)

# Model definitions
def load_efficientnet_b4(pretrained=True):
    """Build a torchvision EfficientNet-B4 whose last linear layer emits ``num_classes`` logits.

    ``num_classes`` is read from module scope. Only ``classifier[1]`` is replaced, with
    a freshly initialised ``nn.Linear`` of the same input width; the rest of the
    network keeps the weights selected by ``pretrained``.
    """
    net = models.efficientnet_b4(pretrained=pretrained)
    head = net.classifier
    head[1] = nn.Linear(head[1].in_features, num_classes)
    return net

def load_convnext_tiny(pretrained=True):
    """Build a torchvision ConvNeXt-Tiny whose last linear layer emits ``num_classes`` logits.

    ``num_classes`` is read from module scope. Only the final module of ``classifier``
    is replaced, with a freshly initialised ``nn.Linear`` of the same input width.
    """
    net = models.convnext_tiny(pretrained=pretrained)
    head = net.classifier
    head[-1] = nn.Linear(head[-1].in_features, num_classes)
    return net

class DS_Fusionnet(nn.Module):
    """Fusion classifier over two backbones.

    The outputs of an EfficientNet-B4 branch and a ConvNeXt-Tiny branch (built by
    ``load_efficientnet_b4`` / ``load_convnext_tiny``) are concatenated along dim 1.
    A linear layer then maps them to ``num_classes`` logits.

    Args:
        num_classes: number of output logits of the fusion layer.
    """

    def __init__(self, num_classes):
        super(DS_Fusionnet, self).__init__()
        self.efficientnet = load_efficientnet_b4(pretrained=True)
        self.convnext = load_convnext_tiny(pretrained=True)

        # Both branches end in the nn.Linear heads installed by the load_* helpers,
        # so their output widths can be read directly. The previous approach pushed a
        # random tensor through both networks in train mode with autograd enabled.
        # That wasted work and could update normalisation running statistics
        # (e.g. BatchNorm) with noise.
        eff_out_dim = self.efficientnet.classifier[1].out_features
        convnext_out_dim = self.convnext.classifier[-1].out_features

        self.fc = nn.Linear(eff_out_dim + convnext_out_dim, num_classes)

    def forward(self, x):
        """Return fused logits of shape (batch, num_classes) for an image batch ``x``."""
        eff_out = self.efficientnet(x)
        convnext_out = self.convnext(x)
        fused_out = torch.cat((eff_out, convnext_out), dim=1)
        out = self.fc(fused_out)
        return out

# Two-teacher knowledge-distillation loss
def distillation_loss(y, labels, teacher_scores1, teacher_scores2, T=3, alpha=0.5):
    """Blend soft-target distillation from two teachers with hard-label cross-entropy.

    For each teacher, ``KLDivLoss(batchmean)`` is computed between the student's
    temperature-softened log-probabilities and that teacher's softened probabilities.
    Each term is scaled by ``alpha * T**2``, and the two terms are averaged.
    ``(1 - alpha) * CrossEntropy(y, labels)`` is then added.

    Args:
        y: student logits, classes along dim 1.
        labels: integer class targets.
        teacher_scores1, teacher_scores2: teacher logits with the same layout as ``y``.
        T: softmax temperature.
        alpha: weight of the distillation part versus the hard-label part.

    Returns:
        A scalar loss tensor.
    """
    kl = nn.KLDivLoss(reduction='batchmean')
    student_log_probs = nn.functional.log_softmax(y / T, dim=1)
    kd_weight = alpha * T * T
    kd_terms = [
        kl(student_log_probs, nn.functional.softmax(scores / T, dim=1)) * kd_weight
        for scores in (teacher_scores1, teacher_scores2)
    ]
    hard_loss = nn.CrossEntropyLoss()(y, labels) * (1. - alpha)
    return (kd_terms[0] + kd_terms[1]) / 2 + hard_loss

# Training function
def train(model, teacher_model1, teacher_model2, device, train_loader, optimizer, epoch, accumulation_steps=4):
    """Run one distillation epoch of ``model`` against two frozen teachers.

    Teachers run in eval mode under ``torch.no_grad``. Gradients are accumulated over
    ``accumulation_steps`` batches, since each batch loss is divided by that number
    before ``backward``. The optimizer steps after every group and after the last batch.

    Args:
        model: student network to train.
        teacher_model1, teacher_model2: teacher networks providing soft targets.
        device: device the batches are moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer over the student's parameters.
        epoch: epoch number, used only for the progress-bar label.
        accumulation_steps: number of batches per optimizer step.

    Returns:
        Average per-sample distillation loss over the whole epoch.
    """
    model.train()
    teacher_model1.eval()
    teacher_model2.eval()
    optimizer.zero_grad()  # start from clean gradients regardless of caller state
    num_batches = len(train_loader)
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch}')):
        data, target = data.to(device), target.to(device)
        output_student = model(data)
        with torch.no_grad():
            output_teacher1 = teacher_model1(data)
            output_teacher2 = teacher_model2(data)
        batch_loss = distillation_loss(output_student, target, output_teacher1, output_teacher2)

        (batch_loss / accumulation_steps).backward()

        # batch_loss is a per-sample mean; weight it by batch size on every batch so
        # dividing by the dataset size below yields a true per-sample average.
        running_loss += batch_loss.item() * data.size(0)

        if ((batch_idx + 1) % accumulation_steps == 0) or (batch_idx + 1 == num_batches):
            optimizer.step()
            optimizer.zero_grad()

    dataset_size = len(train_loader.dataset)
    avg_train_loss = running_loss / dataset_size if dataset_size else 0.0
    print(f'\nTrain set: Average loss: {avg_train_loss:.4f}')
    return avg_train_loss

# Evaluation function
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` with cross-entropy and top-1 accuracy.

    Args:
        model: classifier returning logits with classes along dim 1.
        device: device the batches are moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.

    Returns:
        ``(test_loss, accuracy, all_preds, all_labels)``:
        - ``test_loss`` is the mean per-sample cross-entropy.
        - ``accuracy`` is a percentage.
        - ``all_preds`` and ``all_labels`` are flat lists of ints in loader order.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')  # summed so we can average per sample
    test_loss = 0.0
    correct = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc='Testing'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)  # index of the largest logit per sample
            correct += (pred == target).sum().item()
            # Flat int lists; keepdim=True previously produced a list of 1-element arrays.
            all_preds.extend(pred.cpu().tolist())
            all_labels.extend(target.cpu().tolist())

    num_samples = len(test_loader.dataset)
    if num_samples:
        test_loss /= num_samples
        accuracy = 100. * correct / num_samples
    else:
        accuracy = 0.0
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples, accuracy))

    return test_loss, accuracy, all_preds, all_labels

# Feature extraction
def extract_features(model, dataloader, device):
    """Run ``model.efficientnet`` over ``dataloader`` and stack the flattened outputs.

    NOTE(review): ``model.efficientnet`` comes from ``load_efficientnet_b4``, whose
    ``classifier[1]`` is an ``nn.Linear(..., num_classes)``. The "features" returned
    here are therefore that branch's class logits, not penultimate embeddings.
    Confirm this is intended.

    Args:
        model: module with an ``efficientnet`` sub-module; it is switched to eval mode.
        dataloader: iterable of ``(data, target)`` tensor batches.
        device: device the input batches are moved to.

    Returns:
        ``(features, labels)``: a 2-D array with one flattened row per sample, and a
        1-D label array in the same order. An empty loader yields arrays with zero rows.
    """
    model.eval()
    feature_batches = []
    labels = []

    with torch.no_grad():
        for i, (data, target) in enumerate(dataloader):
            batch_out = model.efficientnet(data.to(device))
            batch_features = torch.flatten(batch_out, start_dim=1).cpu().numpy()
            feature_batches.append(batch_features)
            labels.extend(target.cpu().numpy())

            # Per-batch debug output: actual batch feature shape and labels collected so far.
            print(f"Batch {i}: Features shape: {batch_features.shape}, Labels so far: {len(labels)}")

    if not feature_batches:
        return np.empty((0, 0), dtype=np.float32), np.empty((0,), dtype=np.int64)

    features = np.concatenate(feature_batches)
    labels = np.array(labels)
    return features, labels

# Additive-noise helper
def add_noise(features, noise_level=0.1, rng=None):
    """Return a copy of ``features`` with i.i.d. zero-mean Gaussian noise added.

    Args:
        features: array-like of any shape; it is converted with ``np.asarray`` and not
            modified in place.
        noise_level: standard deviation of the noise.
        rng: optional random source with a ``normal(loc, scale, size)`` method (e.g.
            ``np.random.default_rng(seed)``) for reproducible noise. When None, the
            global ``np.random`` state is used, as before.

    Returns:
        A new float array with the same shape as ``features``.
    """
    features = np.asarray(features)
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, noise_level, features.shape)
    return features + noise

# Select the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Teacher models: pretrained backbones whose last linear layer has just been replaced
# by the load_* helpers with a fresh, randomly initialised num_classes head.
# NOTE(review): the teachers are never fine-tuned on this dataset before they are
# used as distillation targets below, so their logits carry no task-specific
# knowledge. Confirm whether a teacher fine-tuning step was intended here.
teacher_model1 = load_efficientnet_b4(pretrained=True).to(device)
teacher_model2 = load_convnext_tiny(pretrained=True).to(device)

# Student model: fuses the outputs of its own EfficientNet-B4 and ConvNeXt-Tiny branches.
student_model = DS_Fusionnet(num_classes=num_classes).to(device)

# Adam over every student parameter (both backbones and the fusion layer).
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# Train the student by distillation for a fixed number of epochs
# (no validation split or checkpointing).
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(student_model, teacher_model1, teacher_model2, device, train_loader, optimizer, epoch)

# Evaluate the student on the test split; keep its predictions and labels for the reports below.
_, _, student_preds, student_labels = test(student_model, device, test_loader)

# Extract test-split features with the trained student's EfficientNet branch.
features, labels = extract_features(student_model, test_loader, device)

# Debug output: feature and label array shapes.
print(f"Features shape: {features.shape}")
print(f"Labels shape: {labels.shape}")

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if features.shape[0] != labels.shape[0]:
    raise ValueError("Number of features and labels must be the same")

# Gaussian-perturbed copy of the test features used by the SVM baselines and plots.
noisy_features = add_noise(features)

# Binarized test labels over classes 0..num_classes-1 (one column per class when
# num_classes > 2), used for the ROC/PR metrics.
y_true_bin = label_binarize(labels, classes=np.arange(num_classes))

# One-vs-rest SVM baselines (RBF and linear kernels) on the extracted features.
# The SVMs are fitted on features from the TRAINING split and evaluated on the noisy
# test features, so their metrics are held-out. Previously they were fitted and
# scored on the same test features, which overstates their performance.
train_features, train_labels = extract_features(student_model, train_loader, device)
noisy_train_features = add_noise(train_features)
y_train_bin = label_binarize(train_labels, classes=np.arange(num_classes))

svm_rbf_classifier = OneVsRestClassifier(SVC(kernel='rbf', probability=True, random_state=42))
svm_linear_classifier = OneVsRestClassifier(SVC(kernel='linear', probability=True, random_state=42))

svm_rbf_classifier.fit(noisy_train_features, y_train_bin)
svm_linear_classifier.fit(noisy_train_features, y_train_bin)

# Predictions on the held-out (noisy) test features; rows align with y_true_bin.
svm_rbf_preds = svm_rbf_classifier.predict(noisy_features)
svm_linear_preds = svm_linear_classifier.predict(noisy_features)

# Per-class probability estimates for the ROC/PR curves.
svm_rbf_probs = svm_rbf_classifier.predict_proba(noisy_features)
svm_linear_probs = svm_linear_classifier.predict_proba(noisy_features)

# Evaluation metrics for one model
def calculate_metrics(true_labels, predicted_labels, predicted_probs):
    """Compute classification, ROC and precision-recall metrics for one model.

    Args:
        true_labels: binarized ground truth of shape (n_samples, n_classes), e.g. the
            output of ``label_binarize``; each row's class index is its argmax.
        predicted_labels: 1-D sequence of predicted class indices.
        predicted_probs: array of shape (n_samples, n_classes) with per-class scores.

    Returns:
        A 9-tuple ``(report, cm, fpr, tpr, roc_auc, precision, recall, ap, f1)``:
        - ``report`` is ``classification_report(..., output_dict=True)``.
        - ``cm`` is the confusion matrix.
        - ``f1`` is the weighted F1 score.
        - ``fpr``/``tpr``/``roc_auc`` and ``precision``/``recall``/``ap`` are dicts
          keyed by class index plus ``"micro"`` and ``"macro"``.

    The number of classes is taken from ``true_labels.shape[1]``. For the binarized
    labels built in this script with more than two classes, this equals the global
    ``num_classes``.

    NOTE(review): a class with no positive samples gives NaN TPR from ``roc_curve``,
    and that NaN propagates into the macro ROC average.
    """
    true_labels = np.asarray(true_labels)
    predicted_probs = np.asarray(predicted_probs)
    n_classes = true_labels.shape[1]
    y_true = np.argmax(true_labels, axis=1)

    report = classification_report(y_true, predicted_labels, output_dict=True)
    cm = confusion_matrix(y_true, predicted_labels)
    f1 = f1_score(y_true, predicted_labels, average='weighted')

    fpr, tpr, roc_auc = {}, {}, {}
    precision, recall, ap = {}, {}, {}

    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(true_labels[:, i], predicted_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
        precision[i], recall[i], _ = precision_recall_curve(true_labels[:, i], predicted_probs[:, i])
        ap[i] = average_precision_score(true_labels[:, i], predicted_probs[:, i])

    # Micro averages: pool every (sample, class) decision into one binary problem.
    fpr["micro"], tpr["micro"], _ = roc_curve(true_labels.ravel(), predicted_probs.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    precision["micro"], recall["micro"], _ = precision_recall_curve(true_labels.ravel(), predicted_probs.ravel())
    ap["micro"] = average_precision_score(true_labels, predicted_probs, average='micro')

    # Macro ROC: average each class's TPR on the union of all FPR thresholds.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Macro PR: average each class's precision as a function of recall on a common
    # recall grid. precision_recall_curve returns recall in decreasing order and
    # np.interp needs increasing x, so both arrays are reversed. (Previously recall
    # was interpolated onto the FPR grid with precision as x, and recall["macro"]
    # was set equal to the averaged "precision".)
    recall_grid = np.unique(np.concatenate([recall[i] for i in range(n_classes)]))
    mean_precision = np.zeros_like(recall_grid)
    for i in range(n_classes):
        mean_precision += np.interp(recall_grid, recall[i][::-1], precision[i][::-1])
    mean_precision /= n_classes
    precision["macro"] = mean_precision
    recall["macro"] = recall_grid
    ap["macro"] = average_precision_score(true_labels, predicted_probs, average='macro')

    return report, cm, fpr, tpr, roc_auc, precision, recall, ap, f1

def _predict_proba(model, loader, device):
    """Return softmax class probabilities of ``model`` over ``loader``, in loader order."""
    model.eval()
    prob_batches = []
    with torch.no_grad():
        for data, _ in loader:
            logits = model(data.to(device))
            prob_batches.append(torch.softmax(logits, dim=1).cpu().numpy())
    return np.concatenate(prob_batches)


# The student's own probabilities drive its ROC/PR curves. Previously the SVM-RBF
# probabilities were passed here, so the "DS FusionNet" curves were really SVM curves.
# test_loader does not shuffle, so rows line up with y_true_bin and student_preds.
student_probs = _predict_proba(student_model, test_loader, device)

metrics_svm_rbf = calculate_metrics(y_true_bin, np.argmax(svm_rbf_preds, axis=1), svm_rbf_probs)
metrics_svm_linear = calculate_metrics(y_true_bin, np.argmax(svm_linear_preds, axis=1), svm_linear_probs)
metrics_ds_fusionnet = calculate_metrics(y_true_bin, student_preds, student_probs)

# Print the per-class classification reports
print("Classification Report - SVM RBF:")
print(classification_report(np.argmax(y_true_bin, axis=1), np.argmax(svm_rbf_preds, axis=1)))

print("\nClassification Report - SVM Linear:")
print(classification_report(np.argmax(y_true_bin, axis=1), np.argmax(svm_linear_preds, axis=1)))

print("\nClassification Report - DS FusionNet:")
print(classification_report(np.argmax(y_true_bin, axis=1), student_preds))

# Per-class bar chart (used instead of a confusion-matrix heatmap)
def plot_class_accuracy(report, title, ax, n_classes=None):
    """Draw a per-class precision bar chart on ``ax``.

    Despite the function name, the bar heights are the ``'precision'`` entries of a
    ``classification_report(..., output_dict=True)`` dict.

    Args:
        report: classification-report dict keyed by ``str(class_index)``.
        title: axes title.
        ax: matplotlib Axes to draw on.
        n_classes: number of bars. Defaults to the module-level ``num_classes``.

    Classes missing from ``report`` are drawn as 0 instead of raising KeyError.
    ``classification_report`` omits labels absent from both y_true and y_pred.
    """
    if n_classes is None:
        n_classes = num_classes
    class_precisions = [report.get(str(i), {}).get('precision', 0.0) for i in range(n_classes)]
    ax.bar(range(n_classes), class_precisions, color='skyblue')
    ax.set_title(title)
    ax.set_xlabel('Class Index')
    ax.set_ylabel('Precision')
    ax.set_xticks([])
    # Bars are centred on integer positions (matplotlib's default width is 0.8).
    # Pad by half a slot so the first and last bars are not clipped.
    ax.set_xlim([-0.5, n_classes - 0.5])
    ax.tick_params(axis='both', which='major', labelsize=8)
    ax.grid(True, linestyle='--', alpha=0.7)

# One row of three panels: per-class precision for each model.
plt.figure(figsize=(18, 6))

ax1, ax2, ax3 = (plt.subplot(1, 3, _panel) for _panel in (1, 2, 3))
for _axis, _result, _title in (
    (ax1, metrics_svm_rbf, 'Class Precision - SVM RBF'),
    (ax2, metrics_svm_linear, 'Class Precision - SVM Linear'),
    (ax3, metrics_ds_fusionnet, 'Class Precision - DS FusionNet'),
):
    plot_class_accuracy(_result[0], _title, _axis)

plt.tight_layout()

plt.show()

# Overlay the micro- and macro-averaged ROC curves of all three models.
# calculate_metrics returns fpr, tpr and roc_auc at tuple indices 2, 3 and 4.
plt.figure(figsize=(12, 6))

for _model_label, _model_metrics, _curve_color in (
    ('SVM RBF', metrics_svm_rbf, 'deeppink'),
    ('SVM Linear', metrics_svm_linear, 'navy'),
    ('DS FusionNet', metrics_ds_fusionnet, 'green'),
):
    _fpr_d, _tpr_d, _auc_d = _model_metrics[2], _model_metrics[3], _model_metrics[4]
    # Solid line for the micro average, dotted for the macro average.
    for _avg_kind, _style in (('micro', '-'), ('macro', ':')):
        plt.plot(_fpr_d[_avg_kind], _tpr_d[_avg_kind],
                 label=f'{_model_label} {_avg_kind}-average ROC curve (area = {_auc_d[_avg_kind]:0.2f})',
                 color=_curve_color, linestyle=_style, linewidth=2)

# Chance diagonal.
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic - Micro and Macro Averages')
plt.legend(loc="lower right", bbox_to_anchor=(1.05, 0))
plt.tight_layout(rect=[0, 0, 0.95, 1])  # leave room on the right for the legend

plt.show()

# Overlay the micro- and macro-averaged precision-recall curves of all three models.
# calculate_metrics returns precision at tuple index 5, recall at 6 and AP at 7.
# Recall goes on the x-axis and precision on the y-axis to match the axis labels.
# (Previously precision was plotted on x and recall on y.)
plt.figure(figsize=(12, 6))

_pr_models = (
    ('SVM RBF', metrics_svm_rbf, 'deeppink'),
    ('SVM Linear', metrics_svm_linear, 'navy'),
    ('DS FusionNet', metrics_ds_fusionnet, 'green'),
)

for _name, _m, _color in _pr_models:
    _precision, _recall, _ap = _m[5], _m[6], _m[7]
    plt.plot(_recall["micro"], _precision["micro"],
             label=f'{_name} micro-average PR curve (AP = {_ap["micro"]:0.2f})',
             color=_color, linestyle='-', linewidth=2)
    plt.plot(_recall["macro"], _precision["macro"],
             label=f'{_name} macro-average PR curve (AP = {_ap["macro"]:0.2f})',
             color=_color, linestyle=':', linewidth=2)

# Shade the area under each micro-average curve.
for _name, _m, _color in _pr_models:
    plt.fill_between(_m[6]["micro"], 0, _m[5]["micro"], alpha=0.2, color=_color)

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve - Micro and Macro Averages')
plt.legend(loc="best", bbox_to_anchor=(1.05, 0))
plt.tight_layout(rect=[0, 0, 0.95, 1])  # leave room on the right for the legend

plt.show()

# Grouped bar chart comparing overall accuracy and weighted F1 across the three models.
model_names = ["SVM RBF", "SVM Linear", "DS FusionNet"]
_all_metrics = (metrics_svm_rbf, metrics_svm_linear, metrics_ds_fusionnet)
accuracies = [m[0]['accuracy'] for m in _all_metrics]  # report dict is tuple index 0
f1_scores = [m[8] for m in _all_metrics]  # weighted F1 is tuple index 8

x = np.arange(len(model_names))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 6))
# Two bars per model, offset half a bar width to either side of the tick.
rects1 = ax.bar(x - width / 2, accuracies, width, label='Accuracy')
rects2 = ax.bar(x + width / 2, f1_scores, width, label='F1 Score')

ax.set_xlabel('Models')
ax.set_ylabel('Scores')
ax.set_title('Model Performance Comparison')
ax.set_xticks(x)
ax.set_xticklabels(model_names)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

# Add value labels on top of bars
def autolabel(rects, axes=None):
    """Write each bar's height, formatted to two decimals, just above the bar.

    Args:
        rects: iterable of bar patches (objects with ``get_height``, ``get_x`` and
            ``get_width``), e.g. the container returned by ``Axes.bar``.
        axes: Axes to annotate. Defaults to the module-level ``ax`` of the summary
            figure, as before, so existing calls keep working.
    """
    target = ax if axes is None else axes
    for rect in rects:
        height = rect.get_height()
        target.annotate(f'{height:.2f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),  # lift the text 3 points above the bar top
                        textcoords="offset points",
                        ha='center', va='bottom')

# Label the accuracy bars first, then the F1 bars, and render the summary figure.
for _bar_group in (rects1, rects2):
    autolabel(_bar_group)

fig.tight_layout()

plt.show()

# Project the noisy test-set features to 2-D with PCA and t-SNE for visual inspection.
pca = PCA(n_components=2)
reduced_features_pca = pca.fit_transform(noisy_features)

# t-SNE requires perplexity < n_samples; cap it so small test splits do not raise.
_tsne_perplexity = min(30, max(1, noisy_features.shape[0] - 1))
try:
    tsne = TSNE(n_components=2, perplexity=_tsne_perplexity, max_iter=300, random_state=42)
except TypeError:
    # scikit-learn < 1.5 only accepts `n_iter` (renamed to `max_iter` in 1.5,
    # and `n_iter` was removed in 1.7).
    tsne = TSNE(n_components=2, perplexity=_tsne_perplexity, n_iter=300, random_state=42)
reduced_features_tsne = tsne.fit_transform(noisy_features)

# Side-by-side scatter plots coloured by class label.
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
scatter = plt.scatter(reduced_features_pca[:, 0], reduced_features_pca[:, 1], c=labels, cmap='tab20c', edgecolor='k', s=50, alpha=0.7)
plt.title('PCA Visualization')
plt.xlabel('Principal Component 1')
plt.ylabel('Principal Component 2')
# plt.legend already attaches the legend to the axes; re-adding it via add_artist
# would register it a second time.
legend1 = plt.legend(*scatter.legend_elements(), title="Classes", loc='best', bbox_to_anchor=(1.05, 1))

plt.subplot(1, 2, 2)
scatter = plt.scatter(reduced_features_tsne[:, 0], reduced_features_tsne[:, 1], c=labels, cmap='tab20c', edgecolor='k', s=50, alpha=0.7)
plt.title('t-SNE Visualization')
plt.xlabel('Component 1')
plt.ylabel('Component 2')
legend2 = plt.legend(*scatter.legend_elements(), title="Classes", loc='best', bbox_to_anchor=(1.05, 1))

plt.tight_layout(rect=[0, 0, 0.95, 1])  # leave room on the right for the legends

plt.show()

# Persist the trained models: both SVM baselines with joblib and the student's
# state_dict with torch.save.
# NOTE: joblib files are pickles; only load them from trusted locations.
for _artifact, _filename in (
    (svm_rbf_classifier, 'PW_svm_rbf_classifier.pkl'),
    (svm_linear_classifier, 'PW_svm_linear_classifier.pkl'),
):
    joblib.dump(_artifact, _filename)
torch.save(student_model.state_dict(), 'PW_ds_fusionnet.pth')



