In [1]:
import torch
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn
import torch.nn.functional as F

# 配置参数
config = {
    "data_path": "E:/CPCI/plantdoc",
    "model_path": "E:/models/CLIP_VIT-L-14",
    "batch_size": 32,
    "num_workers": 4,
    "num_prompts": 50,
    "image_size": 224,
    "device": "cuda" if torch.cuda.is_available() else "cpu"
}

# 数据加载器
class PlantDataset(Dataset):
    def __init__(self, split_type='train'):
        self.split_map = {'train':1, 'val':2, 'test':0}
        with open(f"{config['data_path']}/trainval.txt") as f:
            self.items = [line.strip().split('=') for line in f]
        
        with open(f"{config['data_path']}/plantwild_prompts.json") as f:
            self.prompts = json.load(f)
            
        self.processor = CLIPProcessor.from_pretrained(config['model_path'])
        
    def __getitem__(self, idx):
        img_path, class_id, split_id = self.items[idx]
        img = Image.open(f"{config['data_path']}/images/{img_path}")
        
        # 文本处理
        class_name = self._get_class_name(class_id)
        text_inputs = self.processor(
            text=self.prompts[class_name][:config['num_prompts']],
            padding=True,
            return_tensors="pt"
        )
        
        # 图像处理
        image_inputs = self.processor(
            images=img,
            return_tensors="pt"
        )
        
        return {
            'image': image_inputs.pixel_values.squeeze(),
            'text': text_inputs.input_ids,
            'attention_mask': text_inputs.attention_mask,
            'label': int(class_id)
        }

# 改进模型
class EnhancedCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(config['model_path'])
        self.cross_attn = nn.MultiheadAttention(embed_dim=768, num_heads=8)
        self.fc = nn.Linear(768*2, 89)  # 89 classes
        
    def forward(self, images, texts):
        # 图像特征
        image_features = self.clip.vision_model(images).last_hidden_state[:,0,:]
        
        # 文本特征
        text_features = self.clip.text_model(
            input_ids=texts.input_ids,
            attention_mask=texts.attention_mask
        ).last_hidden_state[:,0,:]
        
        # 跨模态注意力
        attn_out, _ = self.cross_attn(
            image_features.unsqueeze(1),
            text_features.unsqueeze(1),
            text_features.unsqueeze(1)
        )
        
        # 特征融合
        fused_features = torch.cat([
            image_features,
            attn_out.squeeze(1)
        ], dim=1)
        
        return self.fc(fused_features)

# 训练流程
def train():
    model = EnhancedCLIP().to(config['device'])
    dataset = PlantDataset(split_type='train')
    dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    
    for epoch in range(10):
        for batch in dataloader:
            images = batch['image'].to(config['device'])
            texts = batch['text'].to(config['device'])
            labels = batch['label'].to(config['device'])
            
            outputs = model(images, texts)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print(f"Epoch {epoch} Loss: {loss.item()}")

# 评估代码
def evaluate():
    model.eval()
    # 类似训练流程，计算准确率等指标

if __name__ == "__main__":
    train()
    evaluate()

  from .autonotebook import tqdm as notebook_tqdm


TypeError: object of type 'PlantDataset' has no len()

一、数据加载器优化（适配实际文件结构）

In [1]:
import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
import torch.nn as nn
from torchvision import transforms

class PlantDataset(Dataset):
    def __init__(self, split_type='train'):
        """
        split_type: 'train'(1), 'val'(2), 'test'(0)
        """
        self.split_code = {'train':1, 'val':2, 'test':0}[split_type]
        self.data_root = "E:/CPCI/plantdoc"
        
        # 加载类别映射
        self.class_map = {}
        with open(os.path.join(self.data_root, 'classes.txt'), 'r') as f:
            for line in f:
                class_id, class_name = line.strip().split(' ', 1)
                self.class_map[int(class_id)] = class_name
        
        # 加载划分数据
        self.samples = []
        with open(os.path.join(self.data_root, 'trainval.txt'), 'r') as f:
            for line in f:
                parts = line.strip().split('=')
                if len(parts) < 3: continue
                img_relpath, class_id, split_id = parts[0], int(parts[1]), int(parts[2])
                if split_id == self.split_code:
                    self.samples.append((
                        os.path.join(self.data_root, 'images', img_relpath),
                        class_id
                    ))
        
        # 加载文本提示
        with open(os.path.join(self.data_root, 'plantwild_prompts.json'), 'r') as f:
            self.prompts = json.load(f)
        
        # CLIP预处理
        self.processor = CLIPProcessor.from_pretrained("E:/models/CLIP_VIT-L-14")
        
        # 数据增强
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                               (0.26862954, 0.26130258, 0.27577711))
        ])
        
        self.test_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                               (0.26862954, 0.26130258, 0.27577711))
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_id = self.samples[idx]
        
        # 加载图像
        img = Image.open(img_path).convert('RGB')
        transform = self.train_transform if self.split_code == 1 else self.test_transform
        image = transform(img)
        
        # 获取类别文本提示
        class_name = self.class_map[class_id]
        text_prompts = self.prompts.get(class_name, [class_name])[:50]  # 取前50个提示
        
        # 随机选择一个提示进行训练
        selected_prompt = text_prompts[torch.randint(0, len(text_prompts), (1,)).item()]
        
        return {
            'image': image,
            'text': selected_prompt,
            'label': class_id
        }

  from .autonotebook import tqdm as notebook_tqdm


支持多模态

In [2]:
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor

class EnhancedCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        # 加载CLIP基础模型
        self.clip_model = CLIPModel.from_pretrained("E:/models/CLIP_VIT-L-14")
        self.clip_processor = CLIPProcessor.from_pretrained("E:/models/CLIP_VIT-L-14")
        self.tokenizer = self.clip_processor.tokenizer
        
        # 多模态融合层
        self.fusion = nn.Sequential(
            nn.Linear(768*2, 1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )
        
        # 分类头
        self.classifier = nn.Linear(512, 89)
        
        # 可学习的提示参数
        self.prompt_weights = nn.Parameter(torch.ones(50))

    def forward(self, images, texts):
        # 图像特征
        image_features = self.clip_model.get_image_features(pixel_values=images)
        
        # 文本特征
        text_inputs = self.tokenizer(
            texts, 
            padding=True, 
            return_tensors="pt"
        ).to(images.device)
        text_features = self.clip_model.get_text_features(**text_inputs)
        
        # 特征融合
        fused = torch.cat([image_features, text_features], dim=1)
        fused = self.fusion(fused)
        
        return self.classifier(fused)

三、完整训练流程

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import CLIPModel, CLIPProcessor

class EnhancedCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        # 加载CLIP基础模型
        self.clip_model = CLIPModel.from_pretrained("E:/models/CLIP_VIT-L-14")
        self.clip_processor = CLIPProcessor.from_pretrained("E:/models/CLIP_VIT-L-14")
        self.tokenizer = self.clip_processor.tokenizer
        
        # 多模态融合层
        self.fusion = nn.Sequential(
            nn.Linear(768*2, 1024),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LayerNorm(512)
        )
        
        # 分类头
        self.classifier = nn.Linear(512, 89)
        
        # 可学习的提示参数
        self.prompt_weights = nn.Parameter(torch.ones(50))

    def forward(self, images, texts):
        # 图像特征
        image_features = self.clip_model.get_image_features(pixel_values=images)
        
        # 文本特征
        text_inputs = self.tokenizer(
            texts, 
            padding=True, 
            return_tensors="pt"
        ).to(images.device)
        text_features = self.clip_model.get_text_features(**text_inputs)
        
        # 特征融合
        fused = torch.cat([image_features, text_features], dim=1)
        fused = self.fusion(fused)
        
        return self.classifier(fused)

def train():
    # 配置参数
    config = {
        "batch_size": 16,  # 减少批量大小
        "accumulation_steps": 2,  # 梯度累积步骤
        "lr": 1e-5,
        "epochs": 20,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }
    
    # 初始化
    model = EnhancedCLIP().to(config['device'])
    train_set = PlantDataset(split_type='train')
    val_set = PlantDataset(split_type='val')
    
    train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=config['batch_size'], num_workers=4, pin_memory=True)
    
    optimizer = torch.optim.AdamW([
        {'params': model.clip_model.parameters(), 'lr': 1e-6},
        {'params': model.fusion.parameters()},
        {'params': model.classifier.parameters()}
    ], lr=config['lr'])
    
    criterion = nn.CrossEntropyLoss()
    
    # 训练循环
    best_acc = 0
    for epoch in range(config['epochs']):
        model.train()
        total_loss = 0
        
        for i, batch in enumerate(train_loader):
            images = batch['image'].to(config['device'])
            texts = batch['text']
            labels = batch['label'].to(config['device'])
            
            # 前向传播
            outputs = model(images, texts)
            loss = criterion(outputs, labels)
            
            # 反向传播
            loss = loss / config['accumulation_steps']  # 梯度累积
            loss.backward()
            
            if (i + 1) % config['accumulation_steps'] == 0:
                optimizer.step()
                optimizer.zero_grad()
            
            total_loss += loss.item() * config['accumulation_steps']
        
        # 验证
        val_acc = evaluate(model, val_loader, config['device'])
        print(f"Epoch {epoch+1}/{config['epochs']} | Loss: {total_loss/len(train_loader):.4f} | Val Acc: {val_acc:.2%}")
        
        # 保存最佳模型
        if val_acc > best_acc:
            torch.save(model.state_dict(), 'best_model.pth')
            best_acc = val_acc

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            texts = batch['text']
            labels = batch['label'].to(device)
            
            outputs = model(images, texts)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    return correct / total

if __name__ == "__main__":
    train()

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
