In [1]:
!pip install -q transformers datasets timm


In [4]:
import os
import torch
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor

# Define dataset path
dataset_path = "/Plant_Disease_Dataset"

# Load ViT feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

# Get class names from dataset folders
classes = sorted(os.listdir(dataset_path))  # Ensure consistent class order

# Custom dataset class
class PlantDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Collect images and labels
        for label, class_name in enumerate(classes):
            class_path = os.path.join(root_dir, class_name)
            if os.path.isdir(class_path):
                for img_name in os.listdir(class_path):
                    img_path = os.path.join(class_path, img_name)
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image) if self.transform else image
        label = self.labels[idx]
        return image, torch.tensor(label)

# Load dataset
dataset = PlantDataset(dataset_path, transform)

# Split into train and test sets (80-20 split)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Print sample info
print(f"Classes: {classes}")
print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

FileNotFoundError: [WinError 3] The system cannot find the path specified: '/Plant_Disease_Dataset'