In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, silhouette_score
from scipy.optimize import linear_sum_assignment
from PIL import Image
import os

# 改进的数据集类（支持半监督设置）
class SemiSupervisedPlantDoc(Dataset):
    """Semi-supervised dataset over a ``<root_dir>/train/<class>/<image>`` tree.

    For every class directory the first ``labeled_ratio`` fraction of its
    (name-sorted) images keeps the class index as label; the remaining images
    are exposed with label ``-1`` (unlabeled).

    Args:
        root_dir: Dataset root containing a ``train`` sub-directory.
        labeled_ratio: Fraction of each class that keeps its label.
        transform: Optional callable applied to every loaded PIL image.

    Each item is ``(image, label, is_labeled)``.
    """

    # Lower-case file suffixes accepted as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root_dir, labeled_ratio=0.1, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labeled_samples = []
        self.unlabeled_samples = []

        train_dir = os.path.join(root_dir, 'train')
        # os.listdir order is arbitrary: sort so class indices are reproducible,
        # and ignore stray files that are not class directories.
        class_names = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

        for cls, label in class_to_idx.items():
            class_dir = os.path.join(train_dir, cls)
            # Sorted so the labeled/unlabeled split is the same on every run;
            # non-image files would otherwise crash Image.open later.
            images = sorted(
                name for name in os.listdir(class_dir)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            split_idx = int(len(images) * labeled_ratio)

            for i, img_name in enumerate(images):
                img_path = os.path.join(class_dir, img_name)
                if i < split_idx:
                    self.labeled_samples.append((img_path, label))
                else:
                    # -1 marks an unlabeled sample.
                    self.unlabeled_samples.append((img_path, -1))

        # Labeled samples first, unlabeled samples after.
        self.samples = self.labeled_samples + self.unlabeled_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        is_labeled = (label != -1)
        return image, label, is_labeled

# Preprocessing pipelines.
# Channel statistics passed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: resize, random crop/flip/rotation, then tensor + normalisation.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation: deterministic resize, then tensor + normalisation.
_test_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
test_transform = transforms.Compose(_test_steps)

# Dataset initialisation.
# Raw string: '\d' and '\P' are not valid escape sequences, so the plain
# literal only worked by accident and triggers an invalid-escape warning.
data_root = r'E:\data1\PlantDisease'  # adjust to the local dataset location
full_dataset = SemiSupervisedPlantDoc(
    root_dir=data_root,
    labeled_ratio=0.1,
    transform=train_transform
)

# One loader that mixes labeled and unlabeled samples in every batch.
train_loader = DataLoader(full_dataset, batch_size=64, shuffle=True, num_workers=4)

# 加载DINOv2模型
class DINOv2Wrapper(nn.Module):
    """Pretrained backbone followed by a small MLP projection head.

    Args:
        model_path: Name or local path handed to ``AutoModel.from_pretrained``.

    ``forward`` mean-pools the backbone's ``last_hidden_state`` over dim 1 and
    projects the result to 128 dimensions.
    """

    def __init__(self, model_path):
        super().__init__()
        # Imported locally: this cell's import block never brings AutoModel
        # into scope, so the bare name raised NameError.
        from transformers import AutoModel

        self.model = AutoModel.from_pretrained(model_path)
        # Size the head from the loaded checkpoint; 768 is the previous
        # hard-coded width and remains the fallback.
        hidden_size = getattr(self.model.config, 'hidden_size', 768)
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        features = self.model(x).last_hidden_state.mean(dim=1)
        return self.projection(features)


model = DINOv2Wrapper(r'E:\models\facebook_Dinov2')

# 对比损失函数
class GCDLoss(nn.Module):
    """Equal-weight mix of a supervised contrastive term and an in-batch NCE term.

    Args:
        temperature: Divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def supervised_contrastive(self, features, labels):
        """Mean over anchors of the negative log-probability of same-label rows.

        The positive mask includes the anchor itself (diagonal), and features
        are used as given (no normalisation is applied here).
        """
        column = labels.unsqueeze(1)
        same_class = (column == column.T).float().to(features.device)

        similarity = torch.matmul(features, features.T) / self.temperature
        # Subtract the detached row maximum for numerical stability.
        row_max = similarity.max(dim=1, keepdim=True)[0]
        similarity = similarity - row_max.detach()

        log_denominator = torch.log(torch.exp(similarity).sum(1, keepdim=True))
        log_prob = similarity - log_denominator

        per_anchor = -(same_class * log_prob).sum(1) / same_class.sum(1)
        return per_anchor.mean()

    def forward(self, features, labels, is_labeled):
        """Return ``0.5 * supervised + 0.5 * nce`` for one batch."""
        # Supervised term needs at least two labeled rows.
        labeled_idx = torch.where(is_labeled)[0]
        if len(labeled_idx) <= 1:
            sup_loss = torch.tensor(0.0)
        else:
            sup_loss = self.supervised_contrastive(
                features[labeled_idx],
                labels[labeled_idx],
            )

        # NCE term: every row's target is its own index in the batch.
        pairwise = torch.matmul(features, features.T) / self.temperature
        targets = torch.arange(pairwise.size(0)).to(pairwise.device)
        row_ce = self.cross_entropy(pairwise, targets)
        col_ce = self.cross_entropy(pairwise.T, targets)
        nce_loss = (row_ce + col_ce) / 2

        return 0.5*sup_loss + 0.5*nce_loss

# 半监督聚类模块
class SemiSupervisedKMeans:
    """K-means whose initial centres are seeded from labeled samples.

    Args:
        n_known_classes: Number of classes that may appear as labels
            (``0 .. n_known_classes - 1``); ``-1`` marks unlabeled rows.

    After ``fit`` the fitted estimator is available as ``self.kmeans`` and the
    chosen number of clusters as ``self.total_clusters``.
    """

    def __init__(self, n_known_classes):
        self.n_known = n_known_classes
        self.kmeans = KMeans()

    def fit(self, features, labels=None):
        """Estimate the cluster count, build seeded centres and fit K-means.

        Args:
            features: 2-D NumPy array, one row per sample.
            labels: Optional 1-D array aligned with ``features``; ``-1`` or
                ``None`` means no label information.
        """
        self.estimate_cluster_num(features)
        init_centers = self._build_init_centers(features, labels)

        # The init array must have exactly ``n_clusters`` rows; previously it
        # had ``len(known) + total_clusters`` rows and KMeans rejected it.
        # n_init=1 because explicit centres make extra restarts meaningless.
        self.kmeans = KMeans(
            n_clusters=self.total_clusters, init=init_centers, n_init=1
        )
        self.kmeans.fit(features)

    def _known_class_centers(self, features, labels):
        """Return the mean feature of every known class that has labeled rows."""
        empty = np.empty((0, features.shape[1]), dtype=features.dtype)
        if labels is None:
            return empty
        labels = np.asarray(labels)
        centers = [
            features[labels == c].mean(axis=0)
            for c in range(self.n_known)
            if np.any(labels == c)
        ]
        return np.vstack(centers) if centers else empty

    def _build_init_centers(self, features, labels):
        """Stack known-class means with just enough extra centres."""
        known = self._known_class_centers(features, labels)
        n_extra = self.total_clusters - len(known)
        if n_extra <= 0:
            return known[:self.total_clusters]
        return np.vstack([known, self.initialize_unknown_centers(features, n_extra)])

    def estimate_cluster_num(self, features):
        """Pick the cluster count in ``n_known+1 .. n_known+5`` with the best silhouette."""
        best_score = -np.inf
        best_n = self.n_known + 1

        for n in range(self.n_known + 1, self.n_known + 6):
            candidate = KMeans(n_clusters=n).fit(features)
            score = silhouette_score(features, candidate.labels_)
            if score > best_score:
                best_score = score
                best_n = n

        self.total_clusters = best_n

    def initialize_unknown_centers(self, features, n_centers=None):
        """Return the ``n_centers`` samples farthest from the global feature mean.

        ``n_centers`` defaults to ``self.total_clusters`` (the previous
        behaviour of this method).
        """
        if n_centers is None:
            n_centers = self.total_clusters
        if n_centers <= 0:
            # ``[-0:]`` would select every row, so handle zero explicitly.
            return features[:0]
        distances = np.linalg.norm(features - features.mean(axis=0), axis=1)
        return features[np.argsort(distances)[-n_centers:]]

# 训练流程
def train_gcd(model, dataloader, epochs=50):
    """Train ``model`` with ``GCDLoss`` and run clustering evaluation every epoch.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled)`` batches.
        epochs: Number of passes over ``dataloader``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    loss_fn = GCDLoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0.0

        for images, labels, is_labeled in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            is_labeled = is_labeled.to(device)

            loss = loss_fn(model(images), labels, is_labeled)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Report the mean batch loss, then cluster and score the features.
        mean_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch_idx+1}/{epochs} | Loss: {mean_loss:.4f}")
        evaluate_clustering(model, dataloader, device)
        
def evaluate_clustering(model, dataloader, device):
    """Embed every batch, cluster the features and print ARI on labeled rows.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled)`` batches.
        device: Device the batches are moved to before the forward pass.
    """
    model.eval()
    feature_batches, label_batches = [], []

    with torch.no_grad():
        for images, labels, _ in dataloader:
            embedded = model(images.to(device))
            feature_batches.append(embedded.cpu().numpy())
            label_batches.append(labels.to(device).cpu().numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    # Semi-supervised clustering seeded from the labeled rows.
    cluster_model = SemiSupervisedKMeans(n_known_classes=num_known_classes)
    cluster_model.fit(features, labels)

    # Score only rows that carry a label (-1 marks unlabeled).
    known_mask = (labels != -1)
    predicted = cluster_model.kmeans.labels_[known_mask]
    ari = adjusted_rand_score(labels[known_mask], predicted)
    print(f"Adjusted Rand Index (Known Classes): {ari:.4f}")

# Number of classes, taken from the class list ImageFolder builds for train/.
_train_root = os.path.join(data_root, 'train')
_class_names = datasets.ImageFolder(_train_root).classes
num_known_classes = len(_class_names)
print(f"Number of known classes: {num_known_classes}")

# Start training.
train_gcd(model, train_loader)





In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, silhouette_score
from scipy.optimize import linear_sum_assignment
from PIL import Image
import os
from transformers import AutoImageProcessor, AutoModel

# 改进的数据集类（支持半监督设置）
class SemiSupervisedPlantDoc(Dataset):
    """Semi-supervised dataset over a ``<root_dir>/train/<class>/<image>`` tree.

    For every class directory the first ``labeled_ratio`` fraction of its
    (name-sorted) images keeps the class index as label; the remaining images
    are exposed with label ``-1`` (unlabeled).

    Args:
        root_dir: Dataset root containing a ``train`` sub-directory.
        labeled_ratio: Fraction of each class that keeps its label.
        transform: Optional callable applied to every loaded PIL image.

    Each item is ``(image, label, is_labeled)``.
    """

    # Suffixes are compared against the lower-cased file name, so files such
    # as ``leaf.JPG`` are no longer silently dropped.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir, labeled_ratio=0.1, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labeled_samples = []
        self.unlabeled_samples = []

        train_dir = os.path.join(root_dir, 'train')
        # Only directories are classes; a stray file here used to make the
        # per-class os.listdir call fail.
        class_names = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

        for cls, label in class_to_idx.items():
            class_dir = os.path.join(train_dir, cls)
            # Sorted so the labeled/unlabeled split is identical on every run.
            images = sorted(
                f for f in os.listdir(class_dir)
                if f.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            split_idx = int(len(images) * labeled_ratio)

            for i, img_name in enumerate(images):
                img_path = os.path.join(class_dir, img_name)
                if i < split_idx:
                    self.labeled_samples.append((img_path, label))
                else:
                    # -1 marks an unlabeled sample.
                    self.unlabeled_samples.append((img_path, -1))

        # Labeled samples first, unlabeled samples after.
        self.samples = self.labeled_samples + self.unlabeled_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        is_labeled = (label != -1)
        return image, label, is_labeled

# Preprocessing (no Normalize step in this pipeline).
_augmentation_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),  # last step: convert the PIL image to a tensor
]
train_transform = transforms.Compose(_augmentation_steps)

# 加载DINOv2模型（修正预处理）
class DINOv2Wrapper(nn.Module):
    """Pretrained backbone with its own image processor and an MLP projection head.

    Args:
        model_name: Name or local path handed to ``from_pretrained`` for both
            the image processor and the model.

    ``forward`` expects a batch of image tensors that ``ToTensor`` has already
    scaled to ``[0, 1]`` (as produced by ``train_transform`` above).
    """

    def __init__(self, model_name):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.projection = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        # do_rescale=False: the inputs are already in [0, 1]; letting the
        # processor apply its default 1/255 rescale again would squash every
        # pixel towards zero before normalisation.
        processed = self.processor(images=x, return_tensors="pt", do_rescale=False)
        inputs = processed.pixel_values.to(x.device)
        features = self.model(pixel_values=inputs).last_hidden_state.mean(dim=1)
        return self.projection(features)

# 对比损失函数（保持不变）
class GCDLoss(nn.Module):
    """Equal-weight mix of a supervised contrastive term and an in-batch NCE term.

    Args:
        temperature: Divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def supervised_contrastive(self, features, labels):
        """Mean over anchors of the negative log-probability of same-label rows.

        The positive mask includes the anchor itself (diagonal), and features
        are used as given (no normalisation is applied here).
        """
        column = labels.unsqueeze(1)
        positives = (column == column.T).float().to(features.device)

        scores = torch.matmul(features, features.T) / self.temperature
        # Subtract the detached row maximum for numerical stability.
        scores = scores - scores.max(dim=1, keepdim=True)[0].detach()

        log_partition = torch.log(torch.exp(scores).sum(1, keepdim=True))
        log_prob = scores - log_partition

        anchor_losses = -(positives * log_prob).sum(1) / positives.sum(1)
        return anchor_losses.mean()

    def forward(self, features, labels, is_labeled):
        """Return ``0.5 * supervised + 0.5 * nce`` for one batch."""
        labeled_idx = torch.where(is_labeled)[0]
        # The supervised term needs at least two labeled rows.
        if len(labeled_idx) > 1:
            sup_loss = self.supervised_contrastive(
                features[labeled_idx], labels[labeled_idx]
            )
        else:
            sup_loss = 0.0

        # NCE term: every row's target is its own index in the batch.
        pairwise = torch.matmul(features, features.T) / self.temperature
        targets = torch.arange(pairwise.size(0)).to(pairwise.device)
        row_ce = self.cross_entropy(pairwise, targets)
        col_ce = self.cross_entropy(pairwise.T, targets)
        nce_loss = (row_ce + col_ce) / 2

        return 0.5*sup_loss + 0.5*nce_loss

# 半监督聚类模块（保持不变）
class SemiSupervisedKMeans:
    """K-means whose initial centres are seeded from labeled samples.

    Args:
        n_known_classes: Number of classes that may appear as labels
            (``0 .. n_known_classes - 1``); ``-1`` marks unlabeled rows.

    After ``fit`` the fitted estimator is available as ``self.kmeans`` and the
    chosen number of clusters as ``self.total_clusters``.
    """

    def __init__(self, n_known_classes):
        self.n_known = n_known_classes
        self.kmeans = KMeans()

    def fit(self, features, labels=None):
        """Estimate the cluster count, build seeded centres and fit K-means.

        Args:
            features: 2-D NumPy array, one row per sample.
            labels: Optional 1-D array aligned with ``features``; ``-1`` or
                ``None`` means no label information.
        """
        self.estimate_cluster_num(features)
        init_centers = self._build_init_centers(features, labels)

        # n_init=1: with explicit centres extra restarts are meaningless.
        self.kmeans = KMeans(
            n_clusters=self.total_clusters, init=init_centers, n_init=1
        )
        self.kmeans.fit(features)

    def _build_init_centers(self, features, labels):
        """Return exactly ``total_clusters`` initial centres.

        Known-class means come first; the remainder are the samples farthest
        from the global feature mean (farthest first).
        """
        # Start from a correctly shaped empty array so the case "no labeled
        # rows at all" (or labels=None) no longer breaks np.vstack.
        centers = np.empty((0, features.shape[1]), dtype=features.dtype)
        if labels is not None:
            labels = np.asarray(labels)
            class_means = [
                features[labels == c].mean(axis=0)
                for c in range(self.n_known)
                if np.any(labels == c)
            ]
            if class_means:
                centers = np.vstack(class_means)

        extra_needed = self.total_clusters - len(centers)
        if extra_needed > 0:
            spread = np.linalg.norm(features - features.mean(axis=0), axis=1)
            farthest_first = np.argsort(spread)[::-1]
            centers = np.vstack([centers, features[farthest_first[:extra_needed]]])
        return centers[:self.total_clusters]

    def estimate_cluster_num(self, features):
        """Pick the cluster count in ``n_known+1 .. n_known+5`` with the best silhouette."""
        best_n, best_score = None, -np.inf
        for n in range(self.n_known + 1, self.n_known + 6):
            score = silhouette_score(features, KMeans(n).fit(features).labels_)
            # Strict comparison keeps the first candidate on ties.
            if best_n is None or score > best_score:
                best_n, best_score = n, score
        self.total_clusters = best_n

# 训练流程（添加设备处理）
def train_gcd(model, dataloader, epochs=50):
    """Train ``model`` with ``GCDLoss`` and run clustering evaluation every epoch.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled)`` batches.
        epochs: Number of passes over ``dataloader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    loss_fn = GCDLoss().to(device)

    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0.0
        for batch in dataloader:
            # Move the whole (images, labels, is_labeled) triple to the device.
            images, labels, is_labeled = (t.to(device) for t in batch)

            optimizer.zero_grad()
            loss = loss_fn(model(images), labels, is_labeled)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        mean_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch_idx+1}/{epochs} | Loss: {mean_loss:.4f}")
        evaluate_clustering(model, dataloader, device)

def evaluate_clustering(model, dataloader, device):
    """Embed every batch, cluster the features and print ARI on labeled rows.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled)`` batches.
        device: Device the images are moved to before the forward pass.
    """
    model.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, labels, _ in dataloader:
            embedded = model(images.to(device))
            feature_batches.append(embedded.cpu().numpy())
            label_batches.append(labels.numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    cluster_model = SemiSupervisedKMeans(num_known_classes)
    cluster_model.fit(features, labels)

    # Score only rows that carry a label (-1 marks unlabeled).
    known_mask = (labels != -1)
    predicted = cluster_model.kmeans.labels_[known_mask]
    ari = adjusted_rand_score(labels[known_mask], predicted)
    print(f"Adjusted Rand Index: {ari:.4f}")

# Configuration.
# Raw strings: '\d', '\P' and '\m' are not valid escape sequences, so the
# plain literals only worked by accident and trigger invalid-escape warnings.
data_root = r'E:\data1\PlantDisease'  # replace with the actual dataset path
_train_dir = os.path.join(data_root, 'train')
# Count class directories only, so a stray file in train/ is not a "class".
num_known_classes = sum(
    os.path.isdir(os.path.join(_train_dir, entry))
    for entry in os.listdir(_train_dir)
)

# Dataset and loader.
full_dataset = SemiSupervisedPlantDoc(
    root_dir=data_root,
    labeled_ratio=0.1,
    transform=train_transform
)
train_loader = DataLoader(full_dataset, batch_size=64, shuffle=True, num_workers=4)

# Model (same path value as before, written as a raw string).
model = DINOv2Wrapper(r"E:\models\facebook_Dinov2")

# Start training.
train_gcd(model, train_loader)





  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import adjusted_rand_score, silhouette_score, normalized_mutual_info_score, confusion_matrix
from scipy.optimize import linear_sum_assignment
from PIL import Image
import os
from sklearn.manifold import TSNE
import seaborn as sns
from transformers import AutoImageProcessor, AutoModel

# 改进的数据集类
class SemiSupervisedPlantDoc(Dataset):
    """Semi-supervised dataset over a ``<root_dir>/train/<class>/<image>`` tree.

    For every class directory the first ``labeled_ratio`` fraction of its
    (name-sorted) images keeps the class index as label; the remaining images
    get label ``-1`` (unlabeled). The augmentation pipeline is built in
    ``__init__``.

    Args:
        root_dir: Dataset root containing a ``train`` sub-directory.
        labeled_ratio: Fraction of each class that keeps its label.

    Each item is ``(transformed_image, label, is_labeled)``.
    """

    # Previously ('.jpg', 'png', 'jpeg'): two entries lacked the dot (so a
    # name like "fakepng" matched) and upper-case suffixes were dropped.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, labeled_ratio=0.1):
        self.root_dir = root_dir
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor()
        ])

        train_dir = os.path.join(root_dir, 'train')
        class_names = sorted(
            entry for entry in os.listdir(train_dir)
            if os.path.isdir(os.path.join(train_dir, entry))
        )
        class_to_idx = {cls: i for i, cls in enumerate(class_names)}
        self.samples = []

        # Build the semi-supervised sample list.
        for cls, idx in class_to_idx.items():
            images = self._list_images(os.path.join(train_dir, cls))
            split = int(len(images) * labeled_ratio)

            for i, img_path in enumerate(images):
                # The first `split` images keep their label, the rest get -1.
                self.samples.append((img_path, idx if i < split else -1))

    @classmethod
    def _list_images(cls, cls_dir):
        """Return the image paths in ``cls_dir`` sorted by file name."""
        return [
            os.path.join(cls_dir, name)
            for name in sorted(os.listdir(cls_dir))
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('RGB')
        return self.transform(img), label, (label != -1)

# 增强的DINOv2模型
class EnhancedDINOv2(nn.Module):
    """Frozen pretrained backbone followed by a trainable MLP projection head.

    Args:
        model_name: Name or local path handed to ``from_pretrained`` for both
            the image processor and the backbone.

    ``forward`` expects a batch of image tensors that ``ToTensor`` has already
    scaled to ``[0, 1]``.
    """

    def __init__(self, model_name):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name)
        # Make the freeze explicit: forward already runs the backbone under
        # no_grad, so its weights never receive gradients; marking them
        # non-trainable keeps parameter listings/optimizers consistent.
        self.backbone.requires_grad_(False)
        self.projection = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        with torch.no_grad():  # backbone is frozen
            # do_rescale=False: inputs are already in [0, 1]; the processor's
            # default 1/255 rescale would otherwise be applied a second time.
            processed = self.processor(x, return_tensors="pt", do_rescale=False)
            inputs = processed.pixel_values.to(x.device)
            features = self.backbone(pixel_values=inputs).last_hidden_state.mean(dim=1)
        return self.projection(features)

# 改进的对比损失
class EnhancedGCDLoss(nn.Module):
    """Weighted sum of a supervised term (0.7) and an in-batch contrastive term (0.3).

    Args:
        temp: Divisor applied to the pairwise dot products.
    """

    def __init__(self, temp=0.1):
        super().__init__()
        self.temp = temp
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)

    def supervised_loss(self, f, labels):
        """Negative mean of the same-label entries of the row-wise log-softmax.

        The mean is taken over all entries of the pairwise matrix, and the
        same-label mask includes the diagonal.
        """
        mask = torch.eq(labels.unsqueeze(1), labels.unsqueeze(0))
        logits = torch.matmul(f, f.T) / self.temp
        logits = logits - torch.max(logits, dim=1, keepdim=True)[0].detach()
        return -torch.mean(mask * torch.log_softmax(logits, dim=1))

    def forward(self, features, labels, is_labeled):
        """Return ``0.7 * supervised + 0.3 * self_supervised`` for one batch.

        Args:
            features: ``(batch, dim)`` feature matrix.
            labels: Class index per row, ``-1`` for unlabeled rows.
            is_labeled: Boolean mask selecting the labeled rows.
        """
        labeled = features[is_labeled]
        sup_loss = self.supervised_loss(labeled, labels[is_labeled]) if labeled.size(0) > 1 else 0

        # Self-supervised term over the batch-by-batch similarity matrix.
        # The targets must be row indices: the columns are batch positions,
        # not classes. Feeding the class labels raised "target out of bounds"
        # whenever a label was >= batch size and produced NaN when every
        # label in the batch was -1.
        logits = torch.matmul(features, features.T) / self.temp
        targets = torch.arange(logits.size(0), device=logits.device)
        ssl_loss = self.ce(logits, targets)

        return 0.7*sup_loss + 0.3*ssl_loss

# 改进的聚类模块
class EnhancedClusterer:
    """DBSCAN-based clustering that maps clusters onto known classes or new ids.

    Args:
        n_known: Minimum number of clusters expected; if DBSCAN yields fewer,
            plain K-means with ``n_known`` clusters is used instead.
    """

    def __init__(self, n_known):
        self.n_known = n_known

    def fit_predict(self, features, labels):
        """Assign a cluster id to every row of ``features``.

        Args:
            features: 2-D NumPy array, one row per sample.
            labels: 1-D array aligned with ``features``; ``-1`` marks
                unlabeled rows.

        Returns:
            1-D integer array. DBSCAN clusters whose centroid lies within 0.5
            of a known-class centre take the nearest known label, other
            clusters get fresh ids above the largest known label, and noise
            stays ``-1``. All ``-1`` when no labeled rows exist.
        """
        known_mask = labels != -1
        known_features = features[known_mask]
        known_labels = labels[known_mask]

        unique_known_labels = np.unique(known_labels)
        if unique_known_labels.size == 0:
            print("No known labels found.")
            return np.full_like(labels, -1)

        # Mean feature of every known class.
        known_centers = {
            c: known_features[known_labels == c].mean(axis=0)
            for c in unique_known_labels
        }

        # DBSCAN proposes the clusters, including potential new classes.
        db_labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(features)

        new_label_counter = unique_known_labels.max() + 1
        final_labels = np.full_like(db_labels, -1)
        # np.unique gives a sorted order, so new ids are handed out
        # deterministically (set iteration order was not guaranteed).
        for lbl in np.unique(db_labels):
            if lbl == -1:  # DBSCAN noise
                continue

            cluster_indices = np.where(db_labels == lbl)[0]
            if len(cluster_indices) < 5:  # ignore tiny clusters
                continue

            centroid = features[cluster_indices].mean(axis=0)
            closest = min(
                known_centers,
                key=lambda k: np.linalg.norm(centroid - known_centers[k]),
            )
            if self._is_close_to_any_known(centroid, list(known_centers.values()), threshold=0.5):
                final_labels[cluster_indices] = closest
            else:
                # Too far from every known centre: treat as a new class.
                final_labels[cluster_indices] = new_label_counter
                new_label_counter += 1

        # Fall back to K-means when fewer than n_known real clusters were
        # found. Noise (-1) is excluded from the count; it used to be counted
        # as a cluster.
        if self._count_clusters(final_labels) < self.n_known:
            final_labels = KMeans(n_clusters=self.n_known).fit_predict(features)

        return final_labels

    @staticmethod
    def _count_clusters(cluster_labels):
        """Number of distinct cluster ids, not counting the noise id ``-1``."""
        cluster_labels = np.asarray(cluster_labels)
        return len(np.unique(cluster_labels[cluster_labels != -1]))

    def _estimate_cluster_num(self, features):
        """Return the count in ``n_known .. n_known+4`` with the best silhouette."""
        scores = []
        for n in range(self.n_known, self.n_known+5):
            km = KMeans(n_clusters=n)
            labels = km.fit_predict(features)
            scores.append(silhouette_score(features, labels))
        return self.n_known + np.argmax(scores)

    def _is_close_to_any_known(self, feature, centers, threshold=0.5):
        """True when ``feature`` is closer than ``threshold`` to any centre."""
        return any(np.linalg.norm(feature - center) < threshold for center in centers)

# 训练流程（含可视化）
def train_with_visualization(model, dataloader, epochs=100):
    """Train with ``EnhancedGCDLoss``, evaluate every 5 epochs, then plot curves.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled_mask)`` batches.
        epochs: Number of passes over ``dataloader`` (also the cosine period).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = EnhancedGCDLoss().to(device)

    history = {'loss': [], 'ari': [], 'nmi': []}

    for epoch_no in range(1, epochs + 1):
        model.train()
        running_loss = 0

        for imgs, labels, mask in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(imgs), labels, mask)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()
        avg_loss = running_loss / len(dataloader)
        history['loss'].append(avg_loss)

        # Evaluate and visualise on every fifth epoch only.
        if epoch_no % 5 != 0:
            continue
        ari, nmi = evaluate_and_visualize(model, dataloader, device, epoch_no)
        history['ari'].append(ari)
        history['nmi'].append(nmi)
        print(f"Epoch {epoch_no}: Loss={avg_loss:.4f}, ARI={ari:.4f}, NMI={nmi:.4f}")

    # Training curves: one panel per tracked quantity.
    panels = (
        (131, 'loss', 'Training Loss'),
        (132, 'ari', 'Adjusted Rand Index'),
        (133, 'nmi', 'Normalized Mutual Info'),
    )
    plt.figure(figsize=(15,5))
    for position, key, title in panels:
        plt.subplot(position)
        plt.plot(history[key])
        plt.title(title)
    plt.tight_layout()
    plt.show()

def evaluate_and_visualize(model, dataloader, device, epoch):
    """Cluster the embedded data, score it on labeled rows and draw diagnostics.

    Args:
        model: Module mapping an image batch to feature vectors.
        dataloader: Yields ``(images, labels, is_labeled_mask)`` batches.
        device: Device the images are moved to before the forward pass.
        epoch: Epoch number used in the plot titles.

    Returns:
        ``(ari, nmi)`` computed on labeled rows; both ``nan`` if there are none.
    """
    model.eval()
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for imgs, lbls, _ in dataloader:
            feature_chunks.append(model(imgs.to(device)).cpu().numpy())
            label_chunks.append(lbls.numpy())

    features = np.concatenate(feature_chunks)
    labels = np.concatenate(label_chunks)
    pred_labels = EnhancedClusterer(n_known=num_classes).fit_predict(features, labels)

    # Metrics on rows that carry a label (-1 marks unlabeled).
    known_mask = labels != -1
    has_known = np.sum(known_mask) > 0
    if has_known:
        ari = adjusted_rand_score(labels[known_mask], pred_labels[known_mask])
        nmi = normalized_mutual_info_score(labels[known_mask], pred_labels[known_mask])
    else:
        ari = float('nan')
        nmi = float('nan')

    # 2-D t-SNE embedding, coloured by given labels and by predicted clusters.
    tsne_feats = TSNE(n_components=2, random_state=42).fit_transform(features)

    scatter_panels = (
        (121, labels, f'True Labels (Epoch {epoch})'),
        (122, pred_labels, f'Predicted Clusters (Epoch {epoch})'),
    )
    plt.figure(figsize=(12,5))
    for position, colours, title in scatter_panels:
        plt.subplot(position)
        plt.scatter(tsne_feats[:,0], tsne_feats[:,1], c=colours, cmap='tab20', alpha=0.6)
        plt.title(title)
    plt.show()

    # Confusion matrix, reordered by the assignment that maximises the trace.
    if has_known:
        cm = confusion_matrix(labels[known_mask], pred_labels[known_mask])
        row_ind, col_ind = linear_sum_assignment(-cm)
        cm = cm[row_ind][:, col_ind]

        plt.figure(figsize=(10,8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.xlabel('Predicted Clusters')
        plt.ylabel('True Classes')
        plt.title(f'Confusion Matrix (Epoch {epoch})')
        plt.show()

    return ari, nmi

# Configuration.
# Raw strings: '\d', '\P' and '\m' are not valid escape sequences, so the
# plain literals only worked by accident and trigger invalid-escape warnings.
data_root = r'E:\data1\PlantDisease'
num_classes = 38  # number of classes used by the clusterer

# Dataset and loader.
dataset = SemiSupervisedPlantDoc(data_root, labeled_ratio=0.1)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

# Model (same path value as before, written as a raw string).
model = EnhancedDINOv2(r"E:\models\facebook_Dinov2")

# Start training.
train_with_visualization(model, loader, epochs=50)





  from .autonotebook import tqdm as notebook_tqdm


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from PIL import Image
import os
from sklearn.manifold import TSNE
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm  # 添加进度条库

# Seed the torch and NumPy global RNGs with the same value.
_SEED = 42
torch.manual_seed(_SEED)
np.random.seed(_SEED)

# 自定义数据集类
class PlantDocDataset(Dataset):
    """Image dataset described by an annotation text file, with a seeded split.

    Each non-empty line of ``txt_path`` is split on ``=``; lines with at least
    three fields are read as ``<relative image path>=<integer label>=<ignored>``.
    Images are looked up under ``<root_dir>/images/``; entries whose file does
    not exist are skipped.

    Args:
        root_dir: Dataset root containing the ``images`` directory.
        txt_path: Path to the annotation file.
        transform: Optional callable applied to every loaded PIL image.
        train: Select the training part (True) or the held-out part (False).
        train_ratio: Fraction of the samples assigned to the training part.
        random_seed: Seed of the shuffle that defines the split.

    Each item is ``(image, label)``.
    """

    def __init__(self, root_dir, txt_path, transform=None, train=True, train_ratio=0.8, random_seed=42):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Parse the annotation file line by line.
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('=')
                if len(parts) < 3:
                    continue
                img_rel_path, label_str = parts[0], parts[1]
                img_full_path = os.path.join(
                    root_dir, 'images', img_rel_path.replace('/', os.path.sep)
                )
                if not os.path.exists(img_full_path):
                    continue
                self.samples.append((img_full_path, int(label_str)))

        # Seeded shuffle with a private generator. RandomState(seed) yields
        # the same permutation as np.random.seed(seed) + np.random.shuffle,
        # but constructing a dataset no longer resets the global NumPy RNG.
        indices = list(range(len(self.samples)))
        np.random.RandomState(random_seed).shuffle(indices)
        split_idx = int(train_ratio * len(indices))
        self.indices = indices[:split_idx] if train else indices[split_idx:]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        actual_idx = self.indices[idx]
        img_path, label = self.samples[actual_idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Preprocessing: fixed resize, tensor conversion, channel normalisation.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# Paths.
root_dir = r'E:/data1/plantdoc'
txt_path = r'E:/data1/plantdoc/trainval.txt'

# Number of known classes.
num_classes = 89

# Train and held-out datasets share every argument except the `train` flag.
train_dataset, test_dataset = (
    PlantDocDataset(
        root_dir=root_dir,
        txt_path=txt_path,
        transform=transform,
        train=is_train,
    )
    for is_train in (True, False)
)

# Loaders: shuffle only the training data.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# 模型定义
def load_efficientnet_b4(num_classes):
    """Return an EfficientNet-B4 ('IMAGENET1K_V1' weights) with a new output layer.

    The layer at ``classifier[1]`` is replaced by a linear layer with
    ``num_classes`` outputs and the same input width.
    """
    net = models.efficientnet_b4(weights='IMAGENET1K_V1')
    head = net.classifier
    in_width = head[1].in_features
    head[1] = nn.Linear(in_width, num_classes)
    return net

def load_convnext_tiny(num_classes):
    """Return a ConvNeXt-Tiny ('IMAGENET1K_V1' weights) with a new output layer.

    The last layer of ``classifier`` is replaced by a linear layer with
    ``num_classes`` outputs and the same input width.
    """
    net = models.convnext_tiny(weights='IMAGENET1K_V1')
    head = net.classifier
    in_width = head[-1].in_features
    head[-1] = nn.Linear(in_width, num_classes)
    return net

class FusionNet(nn.Module):
    """Fuse the outputs of an EfficientNet-B4 and a ConvNeXt-Tiny branch.

    Both branches produce ``num_classes`` outputs; their concatenation is
    mapped back to ``num_classes`` by a single linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.efficientnet = load_efficientnet_b4(num_classes)
        self.convnext = load_convnext_tiny(num_classes)
        self.fc = nn.Linear(2*num_classes, num_classes)

    def forward(self, x):
        branch_outputs = [self.efficientnet(x), self.convnext(x)]
        fused = torch.cat(branch_outputs, dim=1)
        return self.fc(fused)

# 知识蒸馏损失
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: temperature-softened KL plus cross-entropy.

    Args:
        T: Softmax temperature applied to both logit sets for the soft term.
        alpha: Weight of the soft term; the hard term gets ``1 - alpha``.
    """

    def __init__(self, T=3, alpha=0.5):
        super().__init__()
        self.T = T
        self.alpha = alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_out, teacher_out, labels):
        """Return ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE``.

        Args:
            student_out: Student logits, ``(batch, classes)``.
            teacher_out: Teacher logits with the same shape.
            labels: Integer class targets for the hard term.
        """
        # The teacher distribution is a fixed target: detach it so the soft
        # term never back-propagates into whatever produced teacher_out.
        teacher_probs = nn.functional.softmax(teacher_out.detach()/self.T, dim=1)
        student_log_probs = nn.functional.log_softmax(student_out/self.T, dim=1)

        # T^2 keeps the soft-term gradient scale comparable across temperatures.
        soft_loss = self.kl_loss(student_log_probs, teacher_probs) * (self.alpha * self.T**2)
        hard_loss = self.ce_loss(student_out, labels) * (1. - self.alpha)
        return soft_loss + hard_loss

# 增强聚类模块
class EnhancedClusterer:
    """DBSCAN-based clustering that maps clusters onto known classes or new ids.

    Args:
        n_known: Minimum number of clusters expected; if DBSCAN yields fewer,
            plain K-means with ``n_known`` clusters is used instead.
    """

    def __init__(self, n_known):
        self.n_known = n_known

    def fit_predict(self, features, labels):
        """Assign a cluster id to every row of ``features``.

        Args:
            features: 2-D NumPy array, one row per sample.
            labels: 1-D array aligned with ``features``; ``-1`` marks
                unlabeled rows.

        Returns:
            ``(final_labels, new_labels)``: the per-row cluster ids (noise
            stays ``-1``) and an array of the ids that were introduced for
            clusters not matching any known class. When no labeled rows
            exist, all ids are ``-1`` and ``new_labels`` is empty.
        """
        known_mask = labels != -1
        known_features = features[known_mask]
        known_labels = labels[known_mask]

        unique_known_labels = np.unique(known_labels)
        if unique_known_labels.size == 0:
            return np.full_like(labels, -1), np.array([])

        # Mean feature of every known class.
        known_centers = {
            c: known_features[known_labels == c].mean(0)
            for c in unique_known_labels
        }

        # DBSCAN proposes the clusters, including potential new classes.
        db_labels = DBSCAN(eps=0.5, min_samples=5).fit_predict(features)

        final_labels = np.full_like(db_labels, -1)
        new_label = unique_known_labels.max() + 1
        new_labels = []

        # np.unique gives a sorted order, so new ids are handed out
        # deterministically (set iteration order was not guaranteed).
        for lbl in np.unique(db_labels):
            if lbl == -1:  # DBSCAN noise
                continue
            cluster_idx = np.where(db_labels == lbl)[0]
            if len(cluster_idx) < 5:  # ignore tiny clusters
                continue

            centroid = features[cluster_idx].mean(0)
            # known_centers is non-empty here, so a nearest class always exists.
            closest = min(
                known_centers,
                key=lambda k: np.linalg.norm(centroid - known_centers[k]),
            )

            if self._is_close(centroid, list(known_centers.values()), 0.5):
                final_labels[cluster_idx] = closest
            else:
                # Too far from every known centre: treat as a new class.
                final_labels[cluster_idx] = new_label
                new_labels.append(new_label)
                new_label += 1

        # Fall back to K-means when fewer than n_known real clusters were
        # found. Noise (-1) is excluded from the count; it used to be counted
        # as a cluster.
        if self._count_clusters(final_labels) < self.n_known:
            final_labels = KMeans(n_clusters=self.n_known).fit_predict(features)
            unique_final_labels = np.unique(final_labels)
            new_labels = unique_final_labels[unique_final_labels >= self.n_known]

        return final_labels, np.array(new_labels)

    @staticmethod
    def _count_clusters(cluster_labels):
        """Number of distinct cluster ids, not counting the noise id ``-1``."""
        cluster_labels = np.asarray(cluster_labels)
        return len(np.unique(cluster_labels[cluster_labels != -1]))

    def _is_close(self, feature, centers, threshold):
        """True when ``feature`` is closer than ``threshold`` to any centre."""
        return any(np.linalg.norm(feature - c) < threshold for c in centers)

# Training procedure
def train_with_visualization(model, dataloader, epochs=100):
    """Train ``model`` on ``dataloader`` and plot loss / ARI / NMI curves.

    Uses AdamW (lr=1e-4, weight_decay=1e-4) with a cosine-annealing schedule
    stepped once per epoch, and ``DistillationLoss`` as the criterion. Every
    fifth epoch the model is evaluated on the same ``dataloader`` and the
    clustering metrics are printed and recorded.

    Args:
        model: network to optimise; moved to CUDA when available, else CPU.
        dataloader: iterable yielding ``(images, labels)`` batches.
        epochs: number of passes over ``dataloader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = DistillationLoss().to(device)

    history = {'loss': [], 'ari': [], 'nmi': []}
    eval_interval = 5

    for epoch in range(epochs):
        model.train()
        running_loss = 0

        # Progress bar over the batches of this epoch
        batches = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for imgs, labels in batches:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            # The same tensor is passed as both student and teacher output.
            # NOTE(review): that appears to make the KL term constant zero - confirm intended.
            batch_loss = loss_fn(outputs, outputs, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        scheduler.step()
        avg_loss = running_loss/len(dataloader)
        history['loss'].append(avg_loss)

        # Periodic evaluation (every `eval_interval` epochs)
        if (epoch+1) % eval_interval != 0:
            continue
        ari, nmi, num_predicted_clusters, num_true_classes = evaluate(model, dataloader, device)
        history['ari'].append(ari)
        history['nmi'].append(nmi)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, ARI={ari:.4f}, NMI={nmi:.4f}, "
              f"Predicted Clusters={num_predicted_clusters}, True Classes={num_true_classes}")

    # Plot the training curves, one panel per recorded series
    plt.figure(figsize=(15,5))
    panels = ((131, 'loss', 'Loss'), (132, 'ari', 'ARI'), (133, 'nmi', 'NMI'))
    for position, key, caption in panels:
        plt.subplot(position)
        plt.plot(history[key], label=caption)
    plt.tight_layout()
    plt.show()

def evaluate(model, dataloader, device):
    """Cluster the model outputs for ``dataloader`` and score them against the labels.

    The raw model outputs are used as features. They are clustered with
    ``EnhancedClusterer`` (``n_known`` taken from the module-level
    ``num_classes``), scored with ARI and NMI, and plotted via
    ``visualize_features``.

    Returns:
        Tuple ``(ari, nmi, num_predicted_clusters, num_true_classes)``.
    """
    model.eval()
    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for imgs, lbls in dataloader:
            batch_out = model(imgs.to(device))
            feature_batches.append(batch_out.cpu().numpy())
            label_batches.append(lbls.numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    pred_labels, new_labels = EnhancedClusterer(n_known=num_classes).fit_predict(features, labels)

    # Clustering quality metrics
    ari = adjusted_rand_score(labels, pred_labels)
    nmi = normalized_mutual_info_score(labels, pred_labels)

    # Number of distinct predicted clusters and of distinct true classes
    num_predicted_clusters = np.unique(pred_labels).size
    num_true_classes = np.unique(labels).size

    # Visualisation
    visualize_features(features, labels, pred_labels, new_labels)
    return ari, nmi, num_predicted_clusters, num_true_classes

def visualize_features(features, true_labels, pred_labels, new_labels):
    """Show a 2-D t-SNE embedding coloured by true labels and by predicted clusters.

    The left panel colours points by ``true_labels``. The right panel draws
    one scatter series per predicted label, using an ``x`` marker for labels
    contained in ``new_labels`` and an ``o`` marker for all others.
    """
    embed = TSNE(n_components=2, random_state=42).fit_transform(features)

    plt.figure(figsize=(12,5))
    plt.subplot(121)
    plt.scatter(embed[:,0], embed[:,1], c=true_labels, cmap='tab20', alpha=0.6)
    plt.title('True Labels')

    plt.subplot(122)
    # Distinguish known classes from newly discovered ones by marker
    for lbl in np.unique(pred_labels):
        member = pred_labels == lbl
        is_new = lbl in new_labels
        plt.scatter(embed[member, 0], embed[member, 1],
                    label=f'New {lbl}' if is_new else f'Known {lbl}',
                    alpha=0.6, marker='x' if is_new else 'o')
    plt.title('Predicted Clusters')
    plt.legend()
    plt.show()

# Initialize the model
model = FusionNet(num_classes=num_classes)

# Start training
train_with_visualization(model, train_loader, epochs=100)

Epoch 1/100: 100%|██████████| 928/928 [07:05<00:00,  2.18it/s]
Epoch 2/100: 100%|██████████| 928/928 [06:56<00:00,  2.23it/s]
Epoch 3/100: 100%|██████████| 928/928 [06:56<00:00,  2.23it/s]
Epoch 4/100: 100%|██████████| 928/928 [06:53<00:00,  2.24it/s]
Epoch 5/100: 100%|██████████| 928/928 [06:56<00:00,  2.23it/s]
[WinError 2] 系统找不到指定的文件。
  File "d:\Users\songyanghui\anaconda3\envs\mvpdr\lib\site-packages\joblib\externals\loky\backend\context.py", line 257, in _count_physical_cores
    cpu_info = subprocess.run(
  File "d:\Users\songyanghui\anaconda3\envs\mvpdr\lib\subprocess.py", line 493, in run
    with Popen(*popenargs, **kwargs) as process:
  File "d:\Users\songyanghui\anaconda3\envs\mvpdr\lib\subprocess.py", line 858, in __init__
    self._execute_child(args, executable, preexec_fn, close_fds,
  File "d:\Users\songyanghui\anaconda3\envs\mvpdr\lib\subprocess.py", line 1327, in _execute_child
    hp, ht, pid, tid = _winapi.CreateProcess(executable, args,
  super()._check_params_vs_i

: 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from PIL import Image
import os
from sklearn.manifold import TSNE
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm  # 添加进度条库

# Seed the torch and NumPy global random number generators
torch.manual_seed(42)
np.random.seed(42)

# Custom dataset class
class PlantDocDataset(Dataset):
    """Image dataset read from an index file, with a seeded train/test split.

    Each non-empty line of ``txt_path`` is split on ``'='``; lines with fewer
    than three fields are skipped. The first field is the image path relative
    to ``<root_dir>/images`` and the second is the integer label. Entries
    whose image file does not exist are dropped.

    Args:
        root_dir: dataset root containing the ``images`` directory.
        txt_path: path of the index file.
        transform: optional callable applied to each loaded RGB image.
        train: select the first ``train_ratio`` share of the shuffled
            indices when True, the remainder otherwise.
        train_ratio: fraction of samples assigned to the training split.
        random_seed: seed of the shuffle; both splits must use the same seed
            to stay disjoint.

    Raises:
        ValueError: if a label field of a kept line is not an integer.
    """

    def __init__(self, root_dir, txt_path, transform=None, train=True, train_ratio=0.8, random_seed=42):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Parse the index file line by line
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('=')
                if len(parts) < 3:
                    continue
                img_rel_path, label_str = parts[0], parts[1]
                img_full_path = os.path.join(root_dir, 'images', img_rel_path.replace('/', os.path.sep))
                if not os.path.exists(img_full_path):
                    continue
                self.samples.append((img_full_path, int(label_str)))

        # Random split. A private RandomState yields the same permutation as
        # np.random.seed(seed) + np.random.shuffle, but leaves the global
        # NumPy RNG untouched, so building a dataset no longer reseeds it.
        indices = list(range(len(self.samples)))
        np.random.RandomState(random_seed).shuffle(indices)
        split_idx = int(train_ratio * len(indices))
        self.indices = indices[:split_idx] if train else indices[split_idx:]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the selected split."""
        img_path, label = self.samples[self.indices[idx]]
        # convert() returns a loaded copy, so the file can be closed right away
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data preprocessing: resize to 224x224, convert to tensor, normalise per channel
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset locations
root_dir = r'E:/data1/plantdoc'
txt_path = r'E:/data1/plantdoc/trainval.txt'

# Number of known classes
num_classes = 89

# Build the datasets and data loaders
train_dataset = PlantDocDataset(
    root_dir=root_dir,
    txt_path=txt_path,
    transform=transform,
    train=True
)

test_dataset = PlantDocDataset(
    root_dir=root_dir,
    txt_path=txt_path,
    transform=transform,
    train=False
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

# Model definitions
def load_efficientnet_b4(num_classes):
    """Build an EfficientNet-B4 with ``IMAGENET1K_V1`` weights and a new head.

    The linear layer at ``classifier[1]`` is replaced by one with the same
    input width and ``num_classes`` outputs.
    """
    backbone = models.efficientnet_b4(weights='IMAGENET1K_V1')
    previous_head = backbone.classifier[1]
    backbone.classifier[1] = nn.Linear(previous_head.in_features, num_classes)
    return backbone

def load_convnext_tiny(num_classes):
    """Build a ConvNeXt-Tiny with ``IMAGENET1K_V1`` weights and a new head.

    The last layer of ``classifier`` is replaced by a linear layer with the
    same input width and ``num_classes`` outputs.
    """
    backbone = models.convnext_tiny(weights='IMAGENET1K_V1')
    previous_head = backbone.classifier[-1]
    backbone.classifier[-1] = nn.Linear(previous_head.in_features, num_classes)
    return backbone

class FusionNet(nn.Module):
    """Two-branch classifier fusing EfficientNet-B4 and ConvNeXt-Tiny outputs.

    Each branch emits ``num_classes`` values; the two outputs are concatenated
    along dim 1 and mapped back to ``num_classes`` by a linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.efficientnet = load_efficientnet_b4(num_classes)
        self.convnext = load_convnext_tiny(num_classes)
        fused_width = 2*num_classes
        self.fc = nn.Linear(fused_width, num_classes)

    def forward(self, x):
        """Run both branches on ``x`` and return the fused output."""
        branch_outputs = [self.efficientnet(x), self.convnext(x)]
        fused = torch.cat(branch_outputs, dim=1)
        return self.fc(fused)

# Knowledge-distillation loss
class DistillationLoss(nn.Module):
    """Weighted sum of a temperature-softened KL term and a cross-entropy term.

    Args:
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the soft (KL) term; the hard (cross-entropy) term
            is weighted by ``1 - alpha``.
    """

    def __init__(self, T=3, alpha=0.5):
        super().__init__()
        self.T, self.alpha = T, alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_out, teacher_out, labels):
        """Return ``alpha * T**2 * KL + (1 - alpha) * CE`` for one batch.

        Softmax is taken over dim 1 of ``student_out`` and ``teacher_out``.
        """
        F = nn.functional
        student_log_probs = F.log_softmax(student_out/self.T, dim=1)
        teacher_probs = F.softmax(teacher_out/self.T, dim=1)

        soft_loss = self.kl_loss(student_log_probs, teacher_probs) * (self.alpha * self.T**2)
        hard_loss = self.ce_loss(student_out, labels) * (1. - self.alpha)
        return soft_loss + hard_loss

# Enhanced clustering module
class EnhancedClusterer:
    """Cluster features with DBSCAN and map clusters to known or new labels.

    Args:
        n_known: number of known classes; also the minimum number of
            clusters required before falling back to KMeans.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN ``min_samples``.
        min_cluster_size: DBSCAN clusters smaller than this are left as -1.
        new_class_threshold: a cluster whose centroid is at least this far
            (Euclidean) from every known-class center gets a new label.
    """

    def __init__(self, n_known, eps=0.5, min_samples=5, min_cluster_size=5,
                 new_class_threshold=0.5):
        self.n_known = n_known
        self.eps = eps
        self.min_samples = min_samples
        self.min_cluster_size = min_cluster_size
        self.new_class_threshold = new_class_threshold

    def fit_predict(self, features, labels):
        """Assign a label to every row of ``features``.

        Args:
            features: 2-D array, one row per sample.
            labels: 1-D array of class ids; -1 marks an unlabeled sample.

        Returns:
            Array of predicted labels. It is all -1 when no sample is
            labeled. Rows DBSCAN treats as noise, or puts in too small a
            cluster, keep -1 unless the KMeans fallback relabels everything.
        """
        known_mask = labels != -1
        known_features = features[known_mask]
        known_labels = labels[known_mask]

        # Known-class centers; np.unique returns the class ids sorted
        unique_known_labels = np.unique(known_labels)
        if unique_known_labels.size == 0:
            return np.full_like(labels, -1)

        known_centers = np.stack([
            known_features[known_labels == c].mean(0) for c in unique_known_labels
        ])

        # Discover candidate clusters with DBSCAN
        db_labels = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit_predict(features)

        # Merge: map each cluster to its nearest known class or to a new id
        final_labels = np.full_like(db_labels, -1)
        new_label = unique_known_labels.max() + 1

        for lbl in np.unique(db_labels):
            if lbl == -1:
                continue
            cluster_idx = np.where(db_labels == lbl)[0]
            if len(cluster_idx) < self.min_cluster_size:
                continue

            # The centroid is computed once per cluster, not once per known class
            centroid = features[cluster_idx].mean(0)
            distances = np.linalg.norm(known_centers - centroid, axis=1)
            nearest = int(np.argmin(distances))

            if distances[nearest] < self.new_class_threshold:
                final_labels[cluster_idx] = unique_known_labels[nearest]
            else:
                final_labels[cluster_idx] = new_label
                new_label += 1

        # Enforce the minimum cluster count. The noise label -1 is not a
        # cluster, so it is excluded from the count.
        n_found = np.unique(final_labels[final_labels != -1]).size
        if n_found < self.n_known:
            # KMeans cannot produce more clusters than there are samples
            n_clusters = min(self.n_known, len(features))
            final_labels = KMeans(n_clusters=n_clusters).fit_predict(features)

        return final_labels

    def _is_close(self, feature, centers, threshold):
        """Return True if ``feature`` is within ``threshold`` of any center."""
        return any(np.linalg.norm(feature - c) < threshold for c in centers)

# Training procedure
def train_with_visualization(model, dataloader, epochs=100):
    """Train ``model`` on ``dataloader`` and plot loss / ARI / NMI curves.

    Uses AdamW (lr=1e-4, weight_decay=1e-4) with a cosine-annealing schedule
    stepped once per epoch, and ``DistillationLoss`` as the criterion. After
    every epoch the model is evaluated on the same ``dataloader`` and the
    clustering metrics are printed and recorded.

    Args:
        model: network to optimise; moved to CUDA when available, else CPU.
        dataloader: iterable yielding ``(images, labels)`` batches.
        epochs: number of passes over ``dataloader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = DistillationLoss().to(device)

    history = {'loss': [], 'ari': [], 'nmi': []}
    eval_interval = 1

    for epoch in range(epochs):
        model.train()
        running_loss = 0

        # Progress bar over the batches of this epoch
        batches = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for imgs, labels in batches:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            # The same tensor is passed as both student and teacher output.
            # NOTE(review): that appears to make the KL term constant zero - confirm intended.
            batch_loss = loss_fn(outputs, outputs, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        scheduler.step()
        avg_loss = running_loss/len(dataloader)
        history['loss'].append(avg_loss)

        # Evaluate every `eval_interval` epochs (an interval of 1 means every epoch)
        if (epoch+1) % eval_interval != 0:
            continue
        ari, nmi = evaluate(model, dataloader, device)
        history['ari'].append(ari)
        history['nmi'].append(nmi)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, ARI={ari:.4f}, NMI={nmi:.4f}")

    # Plot the training curves, one panel per recorded series
    plt.figure(figsize=(15,5))
    panels = ((131, 'loss', 'Loss'), (132, 'ari', 'ARI'), (133, 'nmi', 'NMI'))
    for position, key, caption in panels:
        plt.subplot(position)
        plt.plot(history[key], label=caption)
    plt.tight_layout()
    plt.show()

def evaluate(model, dataloader, device):
    """Cluster the model outputs for ``dataloader`` and score them against the labels.

    The raw model outputs are used as features. They are clustered with
    ``EnhancedClusterer`` (``n_known`` taken from the module-level
    ``num_classes``), scored with ARI and NMI, and plotted via
    ``visualize_features``.

    Returns:
        Tuple ``(ari, nmi)``.
    """
    model.eval()
    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for imgs, lbls in dataloader:
            batch_out = model(imgs.to(device))
            feature_batches.append(batch_out.cpu().numpy())
            label_batches.append(lbls.numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    pred_labels = EnhancedClusterer(n_known=num_classes).fit_predict(features, labels)

    # Clustering quality metrics
    scores = (
        adjusted_rand_score(labels, pred_labels),
        normalized_mutual_info_score(labels, pred_labels),
    )

    # Visualisation
    visualize_features(features, labels, pred_labels)
    return scores

def visualize_features(features, true_labels, pred_labels):
    """Show a 2-D t-SNE embedding coloured by true labels and by predicted clusters."""
    embed = TSNE(n_components=2, random_state=42).fit_transform(features)

    plt.figure(figsize=(12,5))
    # One panel per colouring: true labels on the left, predictions on the right
    panels = ((121, true_labels, 'True Labels'), (122, pred_labels, 'Predicted Clusters'))
    for position, colouring, caption in panels:
        plt.subplot(position)
        plt.scatter(embed[:,0], embed[:,1], c=colouring, cmap='tab20', alpha=0.6)
        plt.title(caption)
    plt.show()

# Initialize the model
model = FusionNet(num_classes=num_classes)

# Start training
train_with_visualization(model, train_loader, epochs=200)

Epoch 1/200:   5%|▍         | 46/928 [00:20<06:38,  2.22it/s]


KeyboardInterrupt: 

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
from PIL import Image
import seaborn as sns
from collections import defaultdict
from tqdm import tqdm  # 添加进度条库
import logging
import psutil  # 用于监控内存使用

# Set the OpenBLAS / MKL thread-count environment variables.
# NOTE(review): these are assigned after numpy and torch have already been
# imported in this cell; BLAS libraries commonly read them at load time, so
# confirm that they actually take effect here.
os.environ["OPENBLAS_NUM_THREADS"] = "24"  # intended to use all cores
os.environ["MKL_NUM_THREADS"] = "24"

# Seed the torch and NumPy global random number generators
torch.manual_seed(42)
np.random.seed(42)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Custom dataset class
class PlantDocDataset(Dataset):
    """Image dataset read from an index file, with a seeded train/test split.

    Each non-empty line of ``txt_path`` is split on ``'='``; lines with fewer
    than three fields are skipped. The first field is the image path relative
    to ``<root_dir>/images`` and the second is the integer label. Entries
    whose image file does not exist are dropped.

    Args:
        root_dir: dataset root containing the ``images`` directory.
        txt_path: path of the index file.
        transform: optional callable applied to each loaded RGB image.
        train: select the first ``train_ratio`` share of the shuffled
            indices when True, the remainder otherwise.
        train_ratio: fraction of samples assigned to the training split.
        random_seed: seed of the shuffle; both splits must use the same seed
            to stay disjoint.

    Raises:
        ValueError: if a label field of a kept line is not an integer.
    """

    def __init__(self, root_dir, txt_path, transform=None, train=True, train_ratio=0.8, random_seed=42):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # Parse the index file line by line
        with open(txt_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('=')
                if len(parts) < 3:
                    continue
                img_rel_path, label_str = parts[0], parts[1]
                img_full_path = os.path.join(root_dir, 'images', img_rel_path.replace('/', os.path.sep))
                if not os.path.exists(img_full_path):
                    continue
                self.samples.append((img_full_path, int(label_str)))

        # Random split. A private RandomState yields the same permutation as
        # np.random.seed(seed) + np.random.shuffle, but leaves the global
        # NumPy RNG untouched, so building a dataset no longer reseeds it.
        indices = list(range(len(self.samples)))
        np.random.RandomState(random_seed).shuffle(indices)
        split_idx = int(train_ratio * len(indices))
        self.indices = indices[:split_idx] if train else indices[split_idx:]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the selected split."""
        img_path, label = self.samples[self.indices[idx]]
        # convert() returns a loaded copy, so the file can be closed right away
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Data preprocessing: resize to 224x224, convert to tensor, normalise per channel
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Dataset locations
root_dir = r'E:/data1/plantdoc'
txt_path = r'E:/data1/plantdoc/trainval.txt'

# Number of known classes
num_classes = 89

# Build the datasets and data loaders
train_dataset = PlantDocDataset(
    root_dir=root_dir,
    txt_path=txt_path,
    transform=transform,
    train=True
)

test_dataset = PlantDocDataset(
    root_dir=root_dir,
    txt_path=txt_path,
    transform=transform,
    train=False
)

# On Windows, DataLoader workers are started with "spawn": each worker
# re-imports the main module and unpickles the dataset. A dataset class
# defined at script or notebook top level without an
# `if __name__ == '__main__':` guard cannot be loaded that way, and the first
# batch never arrives. Load in-process there; keep 24 workers elsewhere.
loader_workers = 0 if os.name == 'nt' else 24

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=loader_workers)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=loader_workers)

# Simplified model
class SimpleNet(nn.Module):
    """EfficientNet-B0 with ``IMAGENET1K_V1`` weights and a ``num_classes``-way head.

    The linear layer at ``classifier[1]`` is replaced by one with the same
    input width and ``num_classes`` outputs.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Smaller EfficientNet variant
        backbone = models.efficientnet_b0(weights='IMAGENET1K_V1')
        previous_head = backbone.classifier[1]
        backbone.classifier[1] = nn.Linear(previous_head.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        output = self.model(x)
        return output

# Knowledge-distillation loss
class DistillationLoss(nn.Module):
    """Weighted sum of a temperature-softened KL term and a cross-entropy term.

    Args:
        T: temperature dividing both sets of logits before the softmax.
        alpha: weight of the soft (KL) term; the hard (cross-entropy) term
            is weighted by ``1 - alpha``.
    """

    def __init__(self, T=3, alpha=0.5):
        super().__init__()
        self.T, self.alpha = T, alpha
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_out, teacher_out, labels):
        """Return ``alpha * T**2 * KL + (1 - alpha) * CE`` for one batch.

        Softmax is taken over dim 1 of ``student_out`` and ``teacher_out``.
        """
        F = nn.functional
        student_log_probs = F.log_softmax(student_out/self.T, dim=1)
        teacher_probs = F.softmax(teacher_out/self.T, dim=1)

        soft_loss = self.kl_loss(student_log_probs, teacher_probs) * (self.alpha * self.T**2)
        hard_loss = self.ce_loss(student_out, labels) * (1. - self.alpha)
        return soft_loss + hard_loss

# Enhanced clustering module
class EnhancedClusterer:
    """Cluster features with DBSCAN and map clusters to known or new labels.

    Args:
        n_known: number of known classes; also the minimum number of
            clusters required before falling back to KMeans.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN ``min_samples``.
        min_cluster_size: DBSCAN clusters smaller than this are left as -1.
        new_class_threshold: a cluster whose centroid is at least this far
            (Euclidean) from every known-class center gets a new label.
    """

    def __init__(self, n_known, eps=0.5, min_samples=5, min_cluster_size=5,
                 new_class_threshold=0.5):
        self.n_known = n_known
        self.eps = eps
        self.min_samples = min_samples
        self.min_cluster_size = min_cluster_size
        self.new_class_threshold = new_class_threshold

    def fit_predict(self, features, labels):
        """Assign a label to every row of ``features``.

        Args:
            features: 2-D array, one row per sample.
            labels: 1-D array of class ids; -1 marks an unlabeled sample.

        Returns:
            Tuple ``(final_labels, new_labels)``. ``final_labels`` holds one
            predicted label per row; it is all -1 when no sample is labeled.
            Rows DBSCAN treats as noise, or puts in too small a cluster, keep
            -1 unless the KMeans fallback relabels everything.
            ``new_labels`` is an array of the label ids introduced for
            clusters that matched no known class.
        """
        known_mask = labels != -1
        known_features = features[known_mask]
        known_labels = labels[known_mask]

        # Known-class centers; np.unique returns the class ids sorted
        unique_known_labels = np.unique(known_labels)
        if unique_known_labels.size == 0:
            return np.full_like(labels, -1), np.array([])

        known_centers = np.stack([
            known_features[known_labels == c].mean(0) for c in unique_known_labels
        ])

        # Discover candidate clusters with DBSCAN
        db_labels = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit_predict(features)

        # Merge: map each cluster to its nearest known class or to a new id
        final_labels = np.full_like(db_labels, -1)
        new_label = unique_known_labels.max() + 1
        new_labels = []

        for lbl in np.unique(db_labels):
            if lbl == -1:
                continue
            cluster_idx = np.where(db_labels == lbl)[0]
            if len(cluster_idx) < self.min_cluster_size:
                continue

            # The centroid is computed once per cluster, not once per known class
            centroid = features[cluster_idx].mean(0)
            distances = np.linalg.norm(known_centers - centroid, axis=1)
            nearest = int(np.argmin(distances))

            if distances[nearest] < self.new_class_threshold:
                final_labels[cluster_idx] = unique_known_labels[nearest]
            else:
                final_labels[cluster_idx] = new_label
                new_labels.append(new_label)
                new_label += 1

        # Enforce the minimum cluster count. The noise label -1 is not a
        # cluster, so it is excluded from the count.
        n_found = np.unique(final_labels[final_labels != -1]).size
        if n_found < self.n_known:
            # KMeans cannot produce more clusters than there are samples
            n_clusters = min(self.n_known, len(features))
            final_labels = KMeans(n_clusters=n_clusters).fit_predict(features)
            unique_final_labels = np.unique(final_labels)
            new_labels = unique_final_labels[unique_final_labels >= self.n_known]

        return final_labels, np.array(new_labels)

    def _is_close(self, feature, centers, threshold):
        """Return True if ``feature`` is within ``threshold`` of any center."""
        return any(np.linalg.norm(feature - c) < threshold for c in centers)

# Training procedure
def train_with_visualization(model, dataloader, epochs=100):
    """Train ``model`` on ``dataloader`` and plot loss / ARI / NMI curves.

    Uses AdamW (lr=1e-4, weight_decay=1e-4) with a cosine-annealing schedule
    stepped once per epoch, and ``DistillationLoss`` as the criterion. After
    every epoch the model is evaluated on the same ``dataloader``, the
    clustering metrics are logged and recorded, and the system memory usage
    is logged.

    Args:
        model: network to optimise; moved to CUDA when available, else CPU.
        dataloader: iterable yielding ``(images, labels)`` batches.
        epochs: number of passes over ``dataloader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    loss_fn = DistillationLoss().to(device)

    history = {'loss': [], 'ari': [], 'nmi': []}
    eval_interval = 1

    for epoch in range(epochs):
        model.train()
        running_loss = 0

        # Progress bar over the batches of this epoch
        batches = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for imgs, labels in batches:
            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            # The same tensor is passed as both student and teacher output.
            # NOTE(review): that appears to make the KL term constant zero - confirm intended.
            batch_loss = loss_fn(outputs, outputs, labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        scheduler.step()
        avg_loss = running_loss/len(dataloader)
        history['loss'].append(avg_loss)

        # Evaluate every `eval_interval` epochs (an interval of 1 means every epoch)
        if (epoch+1) % eval_interval == 0:
            ari, nmi, num_predicted_clusters, num_true_classes = evaluate(model, dataloader, device)
            history['ari'].append(ari)
            history['nmi'].append(nmi)
            logger.info(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, ARI={ari:.4f}, NMI={nmi:.4f}, "
                        f"Predicted Clusters={num_predicted_clusters}, True Classes={num_true_classes}")

        # Report system memory usage after each epoch
        memory_usage = psutil.virtual_memory().percent
        logger.info(f"Memory Usage: {memory_usage}%")

    # Plot the training curves, one panel per recorded series
    plt.figure(figsize=(15,5))
    panels = ((131, 'loss', 'Loss'), (132, 'ari', 'ARI'), (133, 'nmi', 'NMI'))
    for position, key, caption in panels:
        plt.subplot(position)
        plt.plot(history[key], label=caption)
    plt.tight_layout()
    plt.show()

def evaluate(model, dataloader, device):
    """Cluster the model outputs for ``dataloader`` and score them against the labels.

    The raw model outputs are used as features. They are clustered with
    ``EnhancedClusterer`` (``n_known`` taken from the module-level
    ``num_classes``), scored with ARI and NMI, and plotted via
    ``visualize_features``.

    Returns:
        Tuple ``(ari, nmi, num_predicted_clusters, num_true_classes)``.
    """
    model.eval()
    feature_batches = []
    label_batches = []

    with torch.no_grad():
        for imgs, lbls in dataloader:
            batch_out = model(imgs.to(device))
            feature_batches.append(batch_out.cpu().numpy())
            label_batches.append(lbls.numpy())

    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)

    pred_labels, new_labels = EnhancedClusterer(n_known=num_classes).fit_predict(features, labels)

    # Clustering quality metrics
    ari = adjusted_rand_score(labels, pred_labels)
    nmi = normalized_mutual_info_score(labels, pred_labels)

    # Number of distinct predicted clusters and of distinct true classes
    num_predicted_clusters = np.unique(pred_labels).size
    num_true_classes = np.unique(labels).size

    # Visualisation
    visualize_features(features, labels, pred_labels, new_labels)
    return ari, nmi, num_predicted_clusters, num_true_classes

def visualize_features(features, true_labels, pred_labels, new_labels):
    """Show a 2-D t-SNE embedding coloured by true labels and by predicted clusters.

    The top panel colours points by ``true_labels``. The bottom panel draws
    one scatter series per predicted label, using an ``x`` marker for labels
    contained in ``new_labels`` and an ``o`` marker for all others, with the
    legend placed below the axes.
    """
    # This cell's import list omits TSNE, so import it here to keep the
    # function usable without state left over from an earlier cell.
    from sklearn.manifold import TSNE

    tsne = TSNE(n_components=2, random_state=42)
    embed = tsne.fit_transform(features)

    plt.figure(figsize=(12,8))
    plt.subplot(211)
    plt.scatter(embed[:,0], embed[:,1], c=true_labels, cmap='tab20', alpha=0.6)
    plt.title('True Labels')

    plt.subplot(212)
    # Distinguish known classes from newly discovered ones by marker
    handles = []
    legend_labels = []
    for lbl in np.unique(pred_labels):
        member = pred_labels == lbl  # computed once per label
        is_new = lbl in new_labels
        scatter = plt.scatter(embed[member, 0], embed[member, 1],
                              label=f'New {lbl}' if is_new else f'Known {lbl}',
                              alpha=0.6, marker='x' if is_new else 'o')
        handles.append(scatter)
        legend_labels.append(scatter.get_label())

    plt.title('Predicted Clusters')
    plt.legend(handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=5)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# Initialize the model
model = SimpleNet(num_classes=num_classes)

# Start training; any exception is logged and its traceback printed rather than propagated
try:
    train_with_visualization(model, train_loader, epochs=100)
except Exception as e:
    logger.error(f"An error occurred: {e}")
    import traceback
    traceback.print_exc()

Epoch 1/100:   0%|          | 0/1855 [00:00<?, ?it/s]