# In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import silhouette_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms  # added this line
from transformers import AutoImageProcessor, AutoModel


# Improved dataset class (supports the semi-supervised setting)
class SemiSupervisedPlantDoc(Dataset):
    """Image dataset split into a labeled subset and an unlabeled remainder.

    Every non-empty line of ``txt_path`` is split on ``'='``. The first field
    is the image path relative to ``<root_dir>/images`` (written with ``/``
    separators), the second field is an integer class label, and any further
    fields are ignored.

    After shuffling, the first ``int(len(lines) * labeled_ratio)`` lines keep
    their label; all remaining samples get the label ``-1``.

    Args:
        root_dir: Dataset root; images are looked up under ``root_dir/images``.
        txt_path: Path of the annotation list file.
        transform: Optional callable applied to the loaded RGB ``PIL`` image.
        labeled_ratio: Fraction of the lines that keep their label.
        seed: Optional seed for the labeled/unlabeled shuffle. ``None`` (the
            default) shuffles with NumPy's global random state, exactly as
            before; an integer makes the split reproducible without touching
            the global state.

    Raises:
        ValueError: If a line has no ``'='``-separated label field or the
            label is not an integer.
    """

    def __init__(self, root_dir, txt_path, transform=None, labeled_ratio=0.1,
                 seed=None):
        self.root_dir = root_dir
        self.transform = transform
        self.labeled_samples = []
        self.unlabeled_samples = []

        # Parse the raw annotation file, skipping blank lines
        with open(txt_path, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]

        # Randomly split the samples into labeled / unlabeled
        if seed is None:
            np.random.shuffle(lines)
        else:
            np.random.RandomState(seed).shuffle(lines)
        split_idx = int(len(lines) * labeled_ratio)

        for i, line in enumerate(lines):
            parts = line.split('=')
            # Only the first two fields are used, so do not demand a third one.
            if len(parts) < 2:
                raise ValueError(
                    f"Malformed annotation line (expected 'path=label'): {line!r}"
                )
            img_rel_path, label_str = parts[0], parts[1]
            img_path = os.path.join(
                root_dir, 'images', img_rel_path.replace('/', os.path.sep)
            )

            if i < split_idx:  # labeled sample
                self.labeled_samples.append((img_path, int(label_str)))
            else:  # unlabeled sample; -1 marks "no label"
                self.unlabeled_samples.append((img_path, -1))

    def __len__(self):
        """Return the total number of samples (labeled + unlabeled)."""
        return len(self.labeled_samples) + len(self.unlabeled_samples)

    def __getitem__(self, idx):
        """Return ``(image, label, is_labeled)`` for position ``idx``.

        Indices below ``len(self.labeled_samples)`` address labeled samples;
        the rest address unlabeled samples, whose label is ``-1``.
        """
        n_labeled = len(self.labeled_samples)
        if idx < n_labeled:
            img_path, label = self.labeled_samples[idx]
            is_labeled = True
        else:
            img_path, label = self.unlabeled_samples[idx - n_labeled]
            is_labeled = False

        # Close the file handle deterministically; convert() returns a
        # fully loaded copy that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, label, is_labeled

# Augmentation pipeline used for contrastive learning.
# The transforms run in the listed order: random resized crop to 224 (crop
# scale range 0.2-1.0), colour jitter applied with probability 0.8, random
# grayscale with probability 0.2, random horizontal flip, tensor conversion,
# then per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# confirm they match what the backbone was pretrained with.
contrastive_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Build the dataset; labeled_ratio=0.1 keeps the labels of 10% of the samples.
# The paths are machine-specific absolute Windows paths.
full_dataset = SemiSupervisedPlantDoc(
    root_dir=r'E:\data1\plantdoc',
    txt_path=r'E:\data1\plantdoc\trainval.txt',
    transform=contrastive_transform,
    labeled_ratio=0.1
)

# Data loader (labeled and unlabeled samples are mixed in the same batches)
# NOTE(review): num_workers=4 starts worker processes; on Windows this
# presumably requires the script entry point to be guarded with
# `if __name__ == "__main__":` (and may not work from a notebook) - confirm.
train_loader = DataLoader(full_dataset, batch_size=64, shuffle=True, num_workers=4)

# Load the DINOv2 model
class DINOv2Wrapper(nn.Module):
    """Pretrained backbone followed by a small MLP projection head.

    The backbone is loaded with ``AutoModel.from_pretrained(model_path)``.
    Its ``last_hidden_state`` is averaged over dimension 1 and mapped by a
    ``Linear -> ReLU -> Linear`` head to a 128-dimensional embedding.

    Args:
        model_path: Name or path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_path):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path)
        # Take the input width of the head from the backbone configuration so
        # checkpoints of other widths work too; fall back to the previously
        # hard-coded 768 when the config does not expose `hidden_size`.
        config = getattr(self.model, 'config', None)
        hidden_size = getattr(config, 'hidden_size', 768)
        self.projection = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, x):
        """Return the projected embedding for the input batch ``x``."""
        # Mean over dimension 1 of the backbone's last hidden state
        # (presumably the token axis - verify against the backbone output).
        features = self.model(x).last_hidden_state.mean(dim=1)
        return self.projection(features)

# Instantiate the wrapper from a machine-specific local path.
model = DINOv2Wrapper(r'E:\models\facebook_Dinov2')

# Contrastive loss
class GCDLoss(nn.Module):
    """Equal-weight sum of a supervised and an unsupervised contrastive term.

    ``forward`` L2-normalises the embeddings, so similarities are cosine
    similarities divided by ``temperature``. Without normalisation the
    logits scale with the embedding norm and the temperature has no fixed
    meaning.

    NOTE(review): the unsupervised term uses the batch self-similarity matrix
    with target ``i`` for row ``i``, i.e. each sample is its own positive.
    Contrastive objectives normally use a second augmented view as the
    positive; supporting that would need a second feature argument - confirm
    the intended design.

    Args:
        temperature: Divisor applied to the similarity logits.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss()

    def supervised_contrastive(self, features, labels):
        """Supervised contrastive loss over one batch of labeled embeddings.

        For every anchor the positives are the *other* samples with the same
        label. The anchor itself is excluded from both the positives and the
        softmax denominator. Anchors without any positive are skipped, and a
        zero scalar is returned when no anchor has a positive.

        Args:
            features: ``(n, d)`` embeddings; expected to be L2-normalised
                (``forward`` does this before calling).
            labels: ``(n,)`` class labels.

        Returns:
            Scalar tensor, the mean loss over anchors that have a positive.
        """
        device = features.device
        n = features.size(0)
        labels = labels.to(device).reshape(-1, 1)

        not_self = ~torch.eye(n, dtype=torch.bool, device=device)
        positives = torch.eq(labels, labels.T) & not_self
        pos_count = positives.sum(dim=1)
        has_pos = pos_count > 0
        if not bool(has_pos.any()):
            return features.new_zeros(())

        logits = torch.matmul(features, features.T) / self.temperature
        # logsumexp over all other samples; stable without manual max-shift.
        denom = torch.logsumexp(
            logits.masked_fill(~not_self, float('-inf')), dim=1, keepdim=True
        )
        log_prob = logits - denom

        # torch.where (not multiplication) keeps non-positive entries out
        # of both the value and the gradient.
        pos_log_prob = torch.where(
            positives, log_prob, torch.zeros_like(log_prob)
        ).sum(dim=1)
        loss = -pos_log_prob[has_pos] / pos_count[has_pos]
        return loss.mean()

    def forward(self, features, labels, is_labeled):
        """Return ``0.5 * supervised + 0.5 * unsupervised`` for one batch.

        Args:
            features: ``(n, d)`` embeddings (normalised internally).
            labels: ``(n,)`` labels; only entries where ``is_labeled`` is
                true are used.
            is_labeled: ``(n,)`` boolean mask of labeled samples.

        Returns:
            Scalar loss tensor on the device of ``features``.
        """
        features = nn.functional.normalize(features, dim=1)
        device = features.device
        labels = torch.as_tensor(labels, device=device)
        is_labeled = torch.as_tensor(is_labeled, dtype=torch.bool, device=device)

        # Supervised contrastive loss on the labeled samples
        labeled_idx = torch.nonzero(is_labeled, as_tuple=True)[0]
        if labeled_idx.numel() > 1:
            sup_loss = self.supervised_contrastive(
                features[labeled_idx],
                labels[labeled_idx]
            )
        else:
            # Same device/dtype as the features (was a CPU float32 tensor).
            sup_loss = features.new_zeros(())

        # Noise-contrastive term over the whole batch
        logits = torch.matmul(features, features.T) / self.temperature
        targets = torch.arange(logits.size(0), device=device)

        nce_loss = (self.cross_entropy(logits, targets) +
                    self.cross_entropy(logits.T, targets)) / 2

        return 0.5 * sup_loss + 0.5 * nce_loss

# Semi-supervised clustering module
class SemiSupervisedKMeans:
    """K-means whose initial centres are seeded from the labeled samples.

    ``fit`` computes one centre per known class that has labeled samples,
    chooses the total number of clusters by silhouette score, fills the
    remaining initial centres with outlying points and runs ``KMeans`` from
    that initialisation. The fitted estimator is stored in ``self.kmeans``
    and the chosen cluster count in ``self.total_clusters``.

    Args:
        n_known_classes: Number of known classes; labels are expected in
            ``range(n_known_classes)`` and ``-1`` marks unlabeled samples.
        random_state: Optional seed forwarded to every ``KMeans`` instance
            for reproducible results. Defaults to ``None``.
    """

    def __init__(self, n_known_classes, random_state=None):
        self.n_known = n_known_classes
        self.random_state = random_state
        self.kmeans = KMeans()

    def fit(self, features, labels=None):
        """Cluster ``features`` using ``labels`` to seed the known classes.

        Args:
            features: ``(n_samples, n_features)`` array.
            labels: ``(n_samples,)`` array with ``-1`` for unlabeled samples.
                ``None`` means that every sample is unlabeled.
        """
        features = np.asarray(features)
        if labels is None:
            labels = np.full(len(features), -1)
        labels = np.asarray(labels)

        # Use the labeled samples to initialise the known-class centres
        labeled = labels != -1
        known_features = features[labeled]
        known_labels = labels[labeled]

        # One centre per known class that actually has labeled samples
        known_centers = [
            known_features[known_labels == c].mean(axis=0)
            for c in range(self.n_known)
            if np.any(known_labels == c)
        ]

        # Choose the total number of clusters (sets self.total_clusters)
        self.estimate_cluster_num(features)

        # The init array must have exactly `total_clusters` rows, so only the
        # centres not already covered by a known class are filled in.
        n_unknown = self.total_clusters - len(known_centers)
        unknown_centers = self.initialize_unknown_centers(features, n_unknown)
        init_centers = np.vstack(known_centers + [unknown_centers])

        # Run K-means from the semi-supervised initialisation; a single run
        # is enough because the initial centres are given explicitly.
        self.kmeans = KMeans(
            n_clusters=self.total_clusters,
            init=init_centers,
            n_init=1,
            random_state=self.random_state,
        )
        self.kmeans.fit(features)

    def estimate_cluster_num(self, features):
        """Pick the cluster count with the best silhouette score.

        Candidates are ``n_known + 1`` up to ``n_known + 5``, restricted to
        the range ``silhouette_score`` accepts (2 to ``n_samples - 1``). The
        result is stored in ``self.total_clusters``.

        Raises:
            ValueError: If no candidate count is valid for the sample count.
        """
        n_samples = len(features)
        candidate_n = [
            n for n in (self.n_known + i for i in range(1, 6))
            if 2 <= n <= n_samples - 1
        ]
        if not candidate_n:
            raise ValueError(
                f"Cannot estimate the number of clusters: {n_samples} samples "
                f"are too few for {self.n_known} known classes"
            )

        best_score = -np.inf
        best_n = candidate_n[0]

        for n in candidate_n:
            kmeans = KMeans(n_clusters=n, random_state=self.random_state).fit(features)
            score = silhouette_score(features, kmeans.labels_)
            if score > best_score:
                best_score = score
                best_n = n

        self.total_clusters = best_n

    def initialize_unknown_centers(self, features, n_centers=None):
        """Return the samples farthest from the global feature mean.

        Args:
            features: ``(n_samples, n_features)`` array.
            n_centers: Number of rows to return. ``None`` keeps the previous
                behaviour of returning ``self.total_clusters`` rows.

        Returns:
            ``(n_centers, n_features)`` array, ordered by increasing distance
            from the mean.
        """
        features = np.asarray(features)
        if n_centers is None:
            n_centers = self.total_clusters
        if n_centers <= 0:
            # `[-0:]` would select every row, so handle zero explicitly.
            return features[:0]
        distances = np.linalg.norm(features - features.mean(axis=0), axis=1)
        return features[np.argsort(distances)[-n_centers:]]

# Training procedure
def train_gcd(model, dataloader, epochs=50, device=None):
    """Train ``model`` with ``GCDLoss`` and evaluate clustering every epoch.

    Args:
        model: Module mapping an image batch to embeddings.
        dataloader: Iterable of ``(images, labels, is_labeled)`` batches.
        epochs: Number of passes over ``dataloader``.
        device: Optional device to train on. When given, the model is moved
            there. When ``None`` the model stays where it is and batches are
            moved to the device of its first parameter.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    else:
        device = torch.device(device)
        model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    loss_fn = GCDLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for images, labels, is_labeled in dataloader:
            # Batches must be on the model's device, or the forward pass
            # fails when the model is not on the CPU.
            images = images.to(device)
            labels = labels.to(device)
            is_labeled = is_labeled.to(device)

            features = model(images)
            loss = loss_fn(features, labels, is_labeled)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Run clustering and evaluation after every epoch; max(..., 1) avoids
        # a ZeroDivisionError when the loader yields no batches.
        mean_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {mean_loss:.4f}")
        evaluate_clustering(model, dataloader)
        
def evaluate_clustering(model, dataloader):
    """Embed all samples, cluster them and print the ARI on labeled samples.

    The model is switched to eval mode and left in it. Embeddings and labels
    from every batch of ``dataloader`` are collected, clustered with
    ``SemiSupervisedKMeans`` (using the module-level ``num_known_classes``),
    and the adjusted Rand index is computed on samples whose label is not
    ``-1``.

    Args:
        model: Module mapping an image batch to embeddings.
        dataloader: Iterable of ``(images, labels, is_labeled)`` batches.
    """
    model.eval()
    # Feed batches on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _ in dataloader:
            features = model(images.to(device))
            all_features.append(features.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    features = np.concatenate(all_features)
    labels = np.concatenate(all_labels)

    # Run the semi-supervised clustering
    cluster_model = SemiSupervisedKMeans(n_known_classes=num_known_classes)
    cluster_model.fit(features, labels)

    # Compute the clustering metric on the labeled samples only
    known_mask = (labels != -1)
    ari = adjusted_rand_score(labels[known_mask], cluster_model.kmeans.labels_[known_mask])
    print(f"Adjusted Rand Index (Known Classes): {ari:.4f}")

# Number of known classes (set this according to the actual dataset).
# evaluate_clustering reads this module-level name, so it must be defined
# before training starts.
num_known_classes = 89  # example value; adjust to the actual dataset

# Start training
train_gcd(model, train_loader)