文件结构解析

In [1]:
import json
import os
from PIL import Image

# 解析类别文件
with open("E:/CPCI/plantdoc/classes.txt", "r") as f:
    classes = {line.split()[0]: line.split()[1] for line in f.readlines()}

# 解析数据划分文件
split_dict = {"0": "test", "1": "train", "2": "val"}
with open("E:/CPCI/plantdoc/trainval.txt", "r") as f:
    split_data = [line.strip().split('=') for line in f.readlines()]
    image_splits = {os.path.basename(item[0]): split_dict[item[2]] for item in split_data}

# 加载文本提示
with open("E:/CPCI/plantdoc/plantwild_prompts.json", "r") as f:
    text_prompts = json.load(f)  # 结构：{class_id: [prompt1, ..., prompt50]}

小样本数据生成器

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

class FewShotPlantDataset(Dataset):
    def __init__(self, root_dir, mode="train", n_way=5, k_shot=1):
        self.class_folders = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        self.selected_classes = random.sample(self.class_folders, n_way)
        self.samples = []
        
        for cls in self.selected_classes:
            cls_path = os.path.join(root_dir, cls)
            imgs = [os.path.join(cls_path, f) for f in os.listdir(cls_path) 
                   if image_splits.get(os.path.basename(f), "") == mode]
            selected_imgs = random.sample(imgs, k_shot)
            self.samples.extend([(img, cls) for img in selected_imgs])

    def __getitem__(self, idx):
        img_path, cls = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        text = random.choice(text_prompts[cls])  # 随机选择一条文本提示
        return image, text, cls

    def __len__(self):
        return len(self.samples)

跨模态数据增强

In [6]:
from torchvision import transforms
from nlpaug import Augmenter

# 图像增强管道
img_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                        (0.26862954, 0.26130258, 0.27577711))
])

# 文本增强（需先安装nlpaug）
text_aug = Augmenter('contextual_word_embs', action='substitute', model_path='E:\models\\bert_base_uncased')

def augment_text(text):
    return text_aug.augment(text)[0]


TypeError: __init__() got an unexpected keyword argument 'model_path'