数据预处理与增强

In [6]:
import torch
from torchvision import transforms
from PIL import Image
import json
import os

# Path configuration: dataset root, image folder, split list and prompt file.
data_root = "E:/CPCI/plantdoc/"
image_dir = "E:/CPCI/plantdoc/images/"
split_file = f"{data_root}trainval.txt"
prompt_file = f"{data_root}plantwild_prompts.json"

# Training-time preprocessing and augmentation.
# The images are consumed by a frozen CLIP image encoder further down, so they
# are normalised with CLIP's published preprocessing statistics rather than the
# ImageNet ones: a frozen backbone cannot adapt to a shifted input distribution.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
])

# Load the train/val/test partition
def load_split(split_path, encoding="utf-8"):
    """Parse a split file into per-partition sample lists.

    Every non-blank line must look like ``path=class_id=split_id`` where
    ``split_id`` is 0 (test), 1 (train) or 2 (val).

    Args:
        split_path: Path of the split file to read.
        encoding: Text encoding used to read the file.

    Returns:
        A dict with the keys ``'train'``, ``'val'`` and ``'test'``; each value
        is a list of ``(path, class_id)`` tuples in file order.

    Raises:
        ValueError: If a line is malformed or its split id is not 0, 1 or 2.
    """
    split_names = ('test', 'train', 'val')
    splits = {'train': [], 'val': [], 'test': []}
    with open(split_path, encoding=encoding) as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            try:
                # Split from the right so an '=' inside the path is preserved.
                path, class_id, split_id = line.rsplit('=', 2)
                class_id = int(class_id)
                split_id = int(split_id)
            except ValueError as err:
                raise ValueError(
                    f"{split_path}:{line_no}: expected 'path=class_id=split_id', "
                    f"got {line!r}"
                ) from err
            # Explicit range check: a negative id would otherwise index the
            # name table from the end and silently land in the wrong split.
            if not 0 <= split_id < len(split_names):
                raise ValueError(
                    f"{split_path}:{line_no}: split id must be 0, 1 or 2, "
                    f"got {split_id}"
                )
            splits[split_names[split_id]].append((path, class_id))
    return splits

# Load the text prompts. The dataset code below looks entries up with
# str(class_id), so the top-level JSON keys are class ids as strings.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's locale encoding.
with open(prompt_file, encoding="utf-8") as f:
    prompts = json.load(f)

# Custom dataset
class PlantDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, class_id, text_prompts)`` triples.

    Image paths are resolved against the module-level ``image_dir`` and the
    prompt lists come from the module-level ``prompts`` mapping.
    """

    def __init__(self, split_data, transform=None):
        # split_data: sequence of (relative_path, class_id) pairs.
        self.transform = transform
        self.samples = split_data

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it with its class prompts."""
        relative_path, class_id = self.samples[idx]
        image = Image.open(os.path.join(image_dir, relative_path)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        # All prompts stored for this class (the original note says 50 each).
        return image, class_id, prompts[str(class_id)]

多模态模型构建（使用CLIP + 微调）
模型下载链接：https://github.com/openai/CLIP （需安装clip包）

In [None]:
import clip
import torch
import torch.nn as nn
from transformers import CLIPTokenizer

class MultimodalClassifier(nn.Module):
    """Image + text classifier on top of a frozen, locally stored CLIP.

    Each modality is encoded by CLIP, projected to 256 dimensions by a small
    trainable head, concatenated (512) and classified by a linear layer.

    The checkpoint directory holds Hugging Face style files
    (``config.json`` / ``pytorch_model.bin`` / tokenizer files), so it is
    loaded with ``transformers.CLIPModel``. The OpenAI ``clip.load`` helper
    expects a model name or an OpenAI-format checkpoint, not a ``config.json``,
    and its text encoder cannot consume Hugging Face tokenizer output directly.

    NOTE(review): elsewhere in this file ``texts`` are drawn from the prompts
    of the sample's ground-truth class, so the text branch can reveal the
    label. Validation accuracy obtained that way is not a fair estimate.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Imported here because this cell's import list only has CLIPTokenizer.
        from transformers import CLIPModel

        # Raw string: the backslashes are literal path separators.
        model_dir = r"E:\models\CLIP_VIT-L-14"

        # Tokenizer and model are both loaded from the checkpoint directory.
        self.tokenizer = CLIPTokenizer.from_pretrained(model_dir)
        self.clip_model = CLIPModel.from_pretrained(model_dir)

        # Freeze the whole CLIP backbone; only the heads below are trained.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Size the heads from the checkpoint config instead of hard-coding
        # them: the vision and text towers may have different widths.
        vision_dim = self.clip_model.config.vision_config.hidden_size
        text_dim = self.clip_model.config.text_config.hidden_size

        self.image_fc = nn.Sequential(
            nn.Linear(vision_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.final_classifier = nn.Linear(512, num_classes)

    def forward(self, images, texts):
        """Return class logits for a batch.

        Args:
            images: Image batch tensor accepted by the CLIP vision model.
            texts: List of strings, one per image.

        Returns:
            Tensor of logits with one row per sample.
        """
        text_tokens = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(images.device)

        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            # Image feature: the class-token state of the vision tower.
            image_features = self.clip_model.vision_model(
                pixel_values=images
            ).last_hidden_state[:, 0, :]
            # Text feature: the pooled (end-of-text) state. CLIP's text
            # encoder is causal, so position 0 never sees the prompt.
            text_features = self.clip_model.text_model(
                input_ids=text_tokens.input_ids,
                attention_mask=text_tokens.attention_mask,
            ).pooler_output

        image_features = self.image_fc(image_features)
        text_features = self.text_fc(text_features)

        # Feature fusion by concatenation.
        combined = torch.cat([image_features, text_features], dim=1)
        return self.final_classifier(combined)

训练策略优化

In [18]:
import clip
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPModel
from torch.utils.data import DataLoader
from transformers import AdamW

# Build the datasets
splits = load_split(split_file)
train_dataset = PlantDataset(splits['train'], image_transform)
# Validation must be deterministic: reuse the training pipeline minus its
# random augmentation steps so metrics are comparable between epochs.
_random_ops = (
    transforms.RandomHorizontalFlip,
    transforms.RandomRotation,
    transforms.ColorJitter,
)
eval_transform = transforms.Compose(
    [op for op in image_transform.transforms if not isinstance(op, _random_ops)]
)
val_dataset = PlantDataset(splits['val'], eval_transform)

class MultimodalClassifier(nn.Module):
    """Image + text classifier on top of a frozen, locally stored CLIP.

    Each modality is encoded by CLIP, projected to 256 dimensions by a small
    trainable head, concatenated (512) and classified by a linear layer.

    NOTE(review): elsewhere in this file ``texts`` are drawn from the prompts
    of the sample's ground-truth class, so the text branch can reveal the
    label. Validation accuracy obtained that way is not a fair estimate.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Raw strings: the backslashes are literal path separators.
        model_path = r"E:\models\CLIP_VIT-L-14\pytorch_model.bin"
        tokenizer_path = r"E:\models\CLIP_VIT-L-14"

        # Tokenizer and model are both loaded from the checkpoint directory.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.clip_model = CLIPModel.from_pretrained(tokenizer_path)

        # Re-apply the weights from the .bin file. weights_only=True restricts
        # unpickling to tensors and plain containers.
        # NOTE(review): from_pretrained above normally loads these weights
        # already; confirm whether this second load is still needed.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)

        # Drop buffers that newer model definitions no longer register.
        for key in ('text_model.embeddings.position_ids',
                    'vision_model.embeddings.position_ids'):
            state_dict.pop(key, None)

        self.clip_model.load_state_dict(state_dict, strict=False)

        # Freeze the whole CLIP backbone; only the heads below are trained.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Size the heads from the checkpoint config instead of hard-coding
        # them: the vision and text towers may have different widths.
        vision_dim = self.clip_model.config.vision_config.hidden_size
        text_dim = self.clip_model.config.text_config.hidden_size

        self.image_fc = nn.Sequential(
            nn.Linear(vision_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.final_classifier = nn.Linear(512, num_classes)

    def forward(self, images, texts):
        """Return class logits for a batch.

        Args:
            images: Image batch tensor accepted by the CLIP vision model.
            texts: List of strings, one per image.

        Returns:
            Tensor of logits with one row per sample.
        """
        text_tokens = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(images.device)

        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            # Image feature: the class-token state of the vision tower.
            image_features = self.clip_model.vision_model(
                pixel_values=images
            ).last_hidden_state[:, 0, :]
            # Text feature: the pooled (end-of-text) state. CLIP's text
            # encoder is causal, so position 0 never sees the prompt.
            text_features = self.clip_model.text_model(
                input_ids=text_tokens.input_ids,
                attention_mask=text_tokens.attention_mask,
            ).pooler_output

        image_features = self.image_fc(image_features)
        text_features = self.text_fc(text_features)

        # Feature fusion by concatenation.
        combined = torch.cat([image_features, text_features], dim=1)
        return self.final_classifier(combined)

# Create the model, optimizer and loss function
model = MultimodalClassifier(num_classes=89)
# Optimise only the parameters that are still trainable (the CLIP backbone is
# frozen), using PyTorch's AdamW; the transformers.AdamW class is deprecated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Custom collate_fn that also handles the text prompts
def collate_fn(batch):
    """Collate ``(image, class_id, prompts)`` samples into one batch.

    Args:
        batch: List of ``(image_tensor, class_id, prompt_list)`` samples.

    Returns:
        ``(images, labels, texts)``: the stacked image tensor, a tensor of
        class ids, and a list with one randomly chosen prompt per sample.
    """
    images = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    # Draw the index from the actual list length instead of assuming 50
    # prompts per class, which raised IndexError for shorter lists.
    texts = [
        item[2][torch.randint(len(item[2]), (1,)).item()]
        for item in batch
    ]
    return images, labels, texts

# Create the DataLoaders; only the training loader is shuffled.
_loader_options = {"batch_size": 32, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)

# Training loop
def train_epoch(model, loader, optimizer, loss_fn=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(images, texts)`` returning logits.
        loader: Iterable of ``(images, labels, texts)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(outputs, labels)``. When omitted, the
            module-level ``criterion`` is used, as before.

    Returns:
        Mean loss over the batches seen, or ``0.0`` if ``loader`` yielded no
        batch.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    total_loss = 0.0
    # Count batches while iterating so loaders without __len__ also work and
    # an empty loader cannot cause a division by zero.
    num_batches = 0
    for images, labels, texts in loader:
        optimizer.zero_grad()
        outputs = model(images, texts)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Validation loop
def evaluate(model, loader):
    """Compute top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(images, texts)`` returning logits.
        loader: Iterable of ``(images, labels, texts)`` batches.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels, texts in loader:
            predicted = model(images, texts).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty validation set.
    return correct / total if total else 0.0

  state_dict = torch.load(model_path, map_location='cpu')


In [None]:
import clip
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPModel
from torch.utils.data import DataLoader
from transformers import AdamW
import matplotlib.pyplot as plt

# 假设 load_split 和 PlantDataset 已经定义
# splits = load_split(split_file)
# train_dataset = PlantDataset(splits['train'], image_transform)
# val_dataset = PlantDataset(splits['val'], image_transform)

class PlantDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, class_id, text_prompts)`` triples.

    Args:
        data: Sequence of samples whose first two fields are
            ``(image_path, class_id)``, as produced by ``load_split``.
        transform: Optional callable applied to the loaded RGB image.
        prompts_by_class: Optional mapping from ``str(class_id)`` to a list
            of prompts. Defaults to the module-level ``prompts`` loaded from
            ``prompt_file`` instead of a two-class placeholder table.
    """

    def __init__(self, data, transform=None, prompts_by_class=None):
        self.data = data
        self.transform = transform
        self.prompts = prompts if prompts_by_class is None else prompts_by_class

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it with its class prompts.

        Raises:
            KeyError: If the sample's class id has no entry in the prompts.
        """
        # load_split yields (path, class_id) pairs; extra fields are ignored.
        image_path, class_id = self.data[idx][:2]
        # Relative paths are resolved against image_dir; os.path.join leaves
        # an absolute image_path untouched.
        with Image.open(os.path.join(image_dir, image_path)) as raw:
            image = raw.convert('RGB')
        if self.transform:
            image = self.transform(image)

        try:
            text_prompts = self.prompts[str(class_id)]
        except KeyError:
            raise KeyError(
                f"Class ID {class_id} not found in prompts dictionary."
            ) from None

        return image, class_id, text_prompts

# Build the datasets
splits = load_split(split_file)
train_dataset = PlantDataset(splits['train'], image_transform)
# Validation must be deterministic: reuse the training pipeline minus its
# random augmentation steps so metrics are comparable between epochs.
_random_ops = (
    transforms.RandomHorizontalFlip,
    transforms.RandomRotation,
    transforms.ColorJitter,
)
eval_transform = transforms.Compose(
    [op for op in image_transform.transforms if not isinstance(op, _random_ops)]
)
val_dataset = PlantDataset(splits['val'], eval_transform)

class MultimodalClassifier(nn.Module):
    """Image + text classifier on top of a frozen, locally stored CLIP.

    Each modality is encoded by CLIP, projected to 256 dimensions by a small
    trainable head, concatenated (512) and classified by a linear layer.

    NOTE(review): elsewhere in this file ``texts`` are drawn from the prompts
    of the sample's ground-truth class, so the text branch can reveal the
    label. Validation accuracy obtained that way is not a fair estimate.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Raw strings: the backslashes are literal path separators.
        model_path = r"E:\models\CLIP_VIT-L-14\pytorch_model.bin"
        tokenizer_path = r"E:\models\CLIP_VIT-L-14"

        # Tokenizer and model are both loaded from the checkpoint directory.
        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
        self.clip_model = CLIPModel.from_pretrained(tokenizer_path)

        # Re-apply the weights from the .bin file. weights_only=True restricts
        # unpickling to tensors and plain containers.
        # NOTE(review): from_pretrained above normally loads these weights
        # already; confirm whether this second load is still needed.
        state_dict = torch.load(model_path, map_location='cpu', weights_only=True)

        # Drop buffers that newer model definitions no longer register.
        for key in ('text_model.embeddings.position_ids',
                    'vision_model.embeddings.position_ids'):
            state_dict.pop(key, None)

        self.clip_model.load_state_dict(state_dict, strict=False)

        # Freeze the whole CLIP backbone; only the heads below are trained.
        for param in self.clip_model.parameters():
            param.requires_grad = False

        # Size the heads from the checkpoint config instead of hard-coding
        # them: the vision and text towers may have different widths.
        vision_dim = self.clip_model.config.vision_config.hidden_size
        text_dim = self.clip_model.config.text_config.hidden_size

        self.image_fc = nn.Sequential(
            nn.Linear(vision_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.text_fc = nn.Sequential(
            nn.Linear(text_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.final_classifier = nn.Linear(512, num_classes)

    def forward(self, images, texts):
        """Return class logits for a batch.

        Args:
            images: Image batch tensor accepted by the CLIP vision model.
            texts: List of strings, one per image.

        Returns:
            Tensor of logits with one row per sample.
        """
        text_tokens = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(images.device)

        # The backbone is frozen, so no autograd graph is needed for it.
        with torch.no_grad():
            # Image feature: the class-token state of the vision tower.
            image_features = self.clip_model.vision_model(
                pixel_values=images
            ).last_hidden_state[:, 0, :]
            # Text feature: the pooled (end-of-text) state. CLIP's text
            # encoder is causal, so position 0 never sees the prompt.
            text_features = self.clip_model.text_model(
                input_ids=text_tokens.input_ids,
                attention_mask=text_tokens.attention_mask,
            ).pooler_output

        image_features = self.image_fc(image_features)
        text_features = self.text_fc(text_features)

        # Feature fusion by concatenation.
        combined = torch.cat([image_features, text_features], dim=1)
        return self.final_classifier(combined)

# Create the model, optimizer and loss function
model = MultimodalClassifier(num_classes=89)
# Optimise only the parameters that are still trainable (the CLIP backbone is
# frozen), using PyTorch's AdamW; the transformers.AdamW class is deprecated.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Custom collate_fn that also handles the text prompts
def collate_fn(batch):
    """Collate ``(image, class_id, prompts)`` samples into one batch.

    Args:
        batch: List of ``(image_tensor, class_id, prompt_list)`` samples.

    Returns:
        ``(images, labels, texts)``: the stacked image tensor, a tensor of
        class ids, and a list with one randomly chosen prompt per sample.
    """
    images = torch.stack([item[0] for item in batch])
    labels = torch.tensor([item[1] for item in batch])
    # Draw the index from the actual list length instead of assuming 50
    # prompts per class, which raised IndexError for shorter lists.
    texts = [
        item[2][torch.randint(len(item[2]), (1,)).item()]
        for item in batch
    ]
    return images, labels, texts

# Create the DataLoaders; only the training loader is shuffled.
_loader_options = {"batch_size": 32, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)

# Training loop
def train_epoch(model, loader, optimizer, loss_fn=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(images, texts)`` returning logits.
        loader: Iterable of ``(images, labels, texts)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(outputs, labels)``. When omitted, the
            module-level ``criterion`` is used, as before.

    Returns:
        Mean loss over the batches seen, or ``0.0`` if ``loader`` yielded no
        batch.
    """
    if loss_fn is None:
        loss_fn = criterion
    model.train()
    total_loss = 0.0
    # Count batches while iterating so loaders without __len__ also work and
    # an empty loader cannot cause a division by zero.
    num_batches = 0
    for images, labels, texts in loader:
        optimizer.zero_grad()
        outputs = model(images, texts)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Validation loop
def evaluate(model, loader):
    """Compute top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(images, texts)`` returning logits.
        loader: Iterable of ``(images, labels, texts)`` batches.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels, texts in loader:
            predicted = model(images, texts).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty validation set.
    return correct / total if total else 0.0

# Train and validate, recording one loss and one accuracy value per epoch.
num_epochs = 10
train_losses, val_accuracies = [], []

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer)
    train_losses.append(train_loss)

    val_accuracy = evaluate(model, val_loader)
    val_accuracies.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

# Plot the training-loss and validation-accuracy curves side by side.
fig = plt.figure(figsize=(12, 5))

# Left panel: training loss per epoch.
ax_loss = fig.add_subplot(1, 2, 1)
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training Loss')
ax_loss.legend()

# Right panel: validation accuracy per epoch.
ax_acc = fig.add_subplot(1, 2, 2)
ax_acc.plot(val_accuracies, label='Val Accuracy', color='orange')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_title('Validation Accuracy')
ax_acc.legend()

fig.tight_layout()
plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'path/to/split_file.csv'

创新改进点（需集成到代码中）：
文本提示增强：对每个样本随机选择3个提示进行特征融合
注意力融合模块：

In [19]:
class AttentionFusion(nn.Module):
    """Fuse one image vector and one text vector per sample with attention.

    Both inputs are linearly projected, stacked into a two-token sequence per
    sample, passed through multi-head self-attention, and averaged over the
    two tokens.

    Args:
        dim: Feature size of both inputs; must be divisible by the 4 heads.
    """

    def __init__(self, dim):
        super().__init__()
        self.image_proj = nn.Linear(dim, dim)
        self.text_proj = nn.Linear(dim, dim)
        # batch_first=True: forward() builds (batch, 2, dim) sequences. With
        # the default layout the batch axis is read as the sequence axis and
        # attention would mix different samples of the batch.
        self.attention = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)

    def forward(self, image_feat, text_feat):
        """Return the fused ``(batch, dim)`` features.

        Args:
            image_feat: Tensor of shape ``(batch, dim)``.
            text_feat: Tensor of shape ``(batch, dim)``.
        """
        image_proj = self.image_proj(image_feat).unsqueeze(1)
        text_proj = self.text_proj(text_feat).unsqueeze(1)
        # (batch, 2, dim): token 0 is the image, token 1 is the text.
        combined = torch.cat([image_proj, text_proj], dim=1)
        attn_output, _ = self.attention(combined, combined, combined)
        return attn_output.mean(dim=1)

模型保存与推理

In [20]:
def predict_single(image_path, model, class_id=0, num_prompts=5, transform=None):
    """Predict class probabilities for a single image.

    The image is paired with each of the first ``num_prompts`` prompts of
    ``class_id``, and the resulting softmax distributions are averaged.

    NOTE(review): the prompts come from a class chosen by the caller, so a
    class must be assumed before predicting. Confirm this is intended.

    Args:
        image_path: Path of the image file.
        model: Module called as ``model(images, texts)`` returning logits.
        class_id: Class whose prompts are used as the text input.
        num_prompts: How many of that class's prompts to use.
        transform: Optional preprocessing callable. Defaults to
            ``image_transform`` with its random augmentation steps removed,
            so repeated calls give the same result.

    Returns:
        Tensor of shape ``(1, num_classes)`` with the averaged probabilities.

    Raises:
        ValueError: If no prompt is available for ``class_id``.
    """
    if transform is None:
        random_ops = (
            transforms.RandomHorizontalFlip,
            transforms.RandomRotation,
            transforms.ColorJitter,
        )
        transform = transforms.Compose(
            [op for op in image_transform.transforms if not isinstance(op, random_ops)]
        )

    with Image.open(image_path) as raw:
        image = transform(raw.convert('RGB')).unsqueeze(0)

    texts = prompts[str(class_id)][:num_prompts]
    if not texts:
        raise ValueError(f"No prompts available for class {class_id}")

    # Disable dropout while predicting, then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # The model concatenates image and text features row by row, so
            # the image is repeated once per prompt to keep batch sizes equal.
            output = model(image.repeat(len(texts), 1, 1, 1), texts)
    finally:
        model.train(was_training)

    return torch.softmax(output, dim=1).mean(dim=0, keepdim=True)