# In [7]:
import json
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Dataset locations.
# NOTE(review): the recorded run of this cell ended with a FileNotFoundError
# for classes_file. None of these paths has a separator after "data1" or
# "plantdoc", so the directory separators were presumably lost when the paths
# were written down -- confirm the real locations before running.
images_dir = "E:/data1plantdocimages"
classes_file = "E:/data1plantdocclasses.txt"
prompts_file = "E:/data1plantdocplantwildprompts.json"
trainval_file = "E:/data1plantdoctrainval.txt"

# Image preprocessing: resize to a fixed 256x256 and convert to a tensor.
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Read the class table: one "<id> <name>" entry per line.
# Splitting only once keeps names that contain spaces intact, and stripping
# the line first removes the trailing newline from the name. Blank lines are
# skipped; a line without a name raises ValueError.
classes = {}
with open(classes_file, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        class_id, class_name = line.split(maxsplit=1)
        classes[int(class_id)] = class_name

# Read the comma-separated ids that make up the training split.
# Empty tokens (e.g. after a trailing comma) are ignored instead of crashing
# int().
with open(trainval_file, 'r', encoding='utf-8') as f:
    train_indices = {int(token) for token in f.read().split(',') if token.strip()}

# Read the text prompts. The encoding is explicit so the result does not
# depend on the platform's locale default.
with open(prompts_file, 'r', encoding='utf-8') as f:
    prompts = json.load(f)

# 创建数据集类
class PlantDataset(Dataset):
    """Dataset yielding a random image and a random prompt for a class id.

    Item ``idx`` refers to the class id stored at position ``idx`` of
    ``indices``. Every access draws a random file from
    ``<images_dir>/<class_id>/`` and a random entry from
    ``prompts[str(class_id)]``, so repeated accesses of the same index can
    return different samples.

    Args:
        images_dir: Directory containing one sub-directory per class id.
        prompts: Mapping from ``str(class_id)`` to a sequence of prompt strings.
        classes: Mapping used to turn prompt tokens into integer ids.
        indices: Iterable of class ids served by this dataset. Sets are
            sorted so positional indexing is deterministic; any other
            iterable keeps its order.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, images_dir, prompts, classes, indices, transform=None):
        self.images_dir = images_dir
        self.prompts = prompts
        self.classes = classes
        # __getitem__ indexes by position, which a set does not support, so
        # always store a list.
        if isinstance(indices, (set, frozenset)):
            self.indices = sorted(indices)
        else:
            self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        """Return the number of class ids served by this dataset."""
        return len(self.indices)

    def _encode_text(self, text):
        """Convert whitespace-separated tokens of ``text`` to a long tensor.

        Each token is looked up in ``self.classes``; tokens that are not keys
        map to ``len(self.classes)``.
        """
        # NOTE(review): the module-level `classes` dict in this file is keyed
        # by integer class ids, so string tokens never match and every token
        # receives the fallback id -- confirm the intended vocabulary.
        unknown_id = len(self.classes)
        token_ids = [self.classes.get(token, unknown_id) for token in text.split()]
        return torch.tensor(token_ids, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(image, text_tensor, class_id)`` for position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range, or if the class directory
                or the class's prompt list is empty.
            KeyError: If ``prompts`` has no entry for the class id.
        """
        class_id = self.indices[idx]
        class_dir = os.path.join(self.images_dir, str(class_id))
        image_path = os.path.join(class_dir, random.choice(os.listdir(class_dir)))
        # The context manager closes the underlying file promptly; convert()
        # returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        # Pick among however many prompts exist instead of assuming exactly 50.
        text = random.choice(self.prompts[str(class_id)])
        if self.transform is not None:
            image = self.transform(image)
        text_tensor = self._encode_text(text)
        return image, text_tensor, class_id

# Build the datasets and data loaders.
# PlantDataset.__getitem__ indexes `indices` by integer position, which a set
# does not support, so the id sets are passed as sorted lists.
train_ids = sorted(train_indices)
# Validation ids are every known class id that is not in the training split.
# Using the actual keys of `classes` avoids assuming the ids are exactly
# 0..len(classes)-1.
val_ids = sorted(set(classes) - train_indices)

train_dataset = PlantDataset(images_dir, prompts, classes, train_ids, transform=image_transforms)
val_dataset = PlantDataset(images_dir, prompts, classes, val_ids, transform=image_transforms)

# NOTE(review): the default DataLoader collation stacks per-sample tensors and
# needs them to share a shape. The text tensors have one entry per prompt
# word, so prompts of different lengths will presumably need a padding
# collate_fn before these loaders can be iterated -- confirm.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Report the size of the training and validation datasets.
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")


# Output: FileNotFoundError: [Errno 2] No such file or directory: 'E:/data1plantdocclasses.txt'