In [4]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Root folder of the image dataset. ImageFolder expects one sub-directory per
# class; the sub-directory names become the class labels.
dataset_path = "Plant_Dataset"  # Change this if needed

# Seed for the train/test split so the same images land in the same subset on
# every kernel restart (otherwise test accuracy is not comparable across runs
# and previously-trained checkpoints may have seen the "test" images).
SPLIT_SEED = 42

# Preprocessing applied to every image as it is loaded.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 for CNNs like ResNet
    transforms.ToTensor(),          # Convert to a float tensor scaled to [0, 1]
    # (x - 0.5) / 0.5 per channel maps [0, 1] to [-1, 1].
    # NOTE(review): ImageNet-pretrained torchvision weights were trained with
    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]; confirm that the
    # 0.5/0.5 statistics here are intentional before changing them, since any
    # saved checkpoints depend on the normalization used at training time.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load dataset (labels are inferred from the sub-directory names)
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

# Split dataset into training (80%) and testing (20%). The test size is computed
# as the remainder so the two sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# Create DataLoaders: shuffle only the training data; keep evaluation order fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Print dataset info
print(f"Total images: {len(dataset)}")
print(f"Training images: {len(train_dataset)}, Testing images: {len(test_dataset)}")
print(f"Class names: {dataset.classes}")


Total images: 19178
Training images: 15342, Testing images: 3836
Class names: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Background_without_leaves', 'Blueberry___healthy', 'Cherry___Powdery_mildew', 'Cherry___healthy', 'Corn___Cercospora_leaf_spot Gray_leaf_spot', 'Corn___Common_rust', 'Corn___Northern_Leaf_Blight', 'Corn___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)']


In [15]:
# Load the pretrained ResNet18 model.
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of the
# `weights=` argument; it is kept here for compatibility with older installs.
model = models.resnet18(pretrained=True)

# Replace the final fully connected layer so it produces one logit per class.
# The size is taken from the dataset itself rather than hard-coded: ImageFolder
# assigns label indices 0..len(dataset.classes)-1, and a smaller head would make
# the higher label indices out of range for the classifier.
num_classes = len(dataset.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Move the model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
