02 — Data Pipeline (Dataset, Transforms, DataLoader)

This notebook builds the full data input pipeline for the PlantVillage classification project. Before training any deep learning model, we need a reliable way to:
 - load images from disk
 - apply preprocessing and transformations
 - convert images into tensors
 - create batches for efficient training
 - pair images with the correct class labels

In this notebook, we define a custom PyTorch Dataset class, set up training and validation transforms, and build DataLoader objects to feed data into our models. We also visualize sample batches to ensure that images, labels, and shapes are all correct.

In [None]:
import os
import matplotlib.pyplot as plt
# imports for data handling
import torch
from torch.utils.data import Dataset, DataLoader
#loads images and applies transforms
from torchvision import transforms
from PIL import Image

In [None]:
# Dataset locations, relative to the folder this notebook is run from.
train_dir = "../data/PlantVillage/train"
val_dir = "../data/PlantVillage/val"

# Every sub-folder of the training directory is one class.
# Only real, non-hidden directories are kept: stray entries such as .DS_Store
# or .ipynb_checkpoints would otherwise be counted as classes and break the
# per-class file listing later on.
# Sorting keeps the class-name -> label-index mapping stable between runs.
classes = sorted(
  entry
  for entry in os.listdir(train_dir)
  if not entry.startswith(".") and os.path.isdir(os.path.join(train_dir, entry))
)
print("Number of classes: ", len(classes))

In [None]:
# Preprocessing pipelines.
#   - training images   -> what the model learns from
#   - validation images -> held-out images used to check that learning generalises
# Both pipelines currently perform the same two steps:
#   1. Resize to 224x224, the input size ResNet expects.
#   2. ToTensor: PIL pixels in 0-255 become a float tensor with values in 0-1.
train_transforms = transforms.Compose(
  [transforms.Resize((224, 224)), transforms.ToTensor()]
)

val_transforms = transforms.Compose(
  [transforms.Resize((224, 224)), transforms.ToTensor()]
)

In [None]:
class PlantVillageDataset(Dataset):
  """Custom dataset that tells PyTorch how to find, read, transform and label images.

  The expected layout is one sub-folder per class inside ``root_dir``::

      root_dir/<class name>/<image file>

  Args:
    root_dir: Directory that contains the class sub-folders.
    classes: Ordered sequence of class folder names. The position of a name
      in this sequence is the integer label returned for its images.
    transform: Optional callable applied to the RGB PIL image before it is
      returned (for example a ``transforms.Compose``).

  Raises:
    FileNotFoundError: If ``root_dir`` is not an existing directory.
  """

  # File extensions treated as images (compared case-insensitively).
  IMAGE_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".bmp", ".ppm", ".pgm", ".tif", ".tiff", ".webp",
  )

  def __init__(self, root_dir, classes, transform=None):
    # Fail early with a clear message instead of silently building an empty dataset.
    if not os.path.isdir(root_dir):
      raise FileNotFoundError(f"Dataset directory not found: {root_dir}")

    self.root_dir = root_dir
    self.transform = transform
    self.classes = classes

    self.image_paths = []
    self.labels = []

    for idx, cls in enumerate(classes):
      class_path = os.path.join(root_dir, cls)

      # A class may have no folder in this split (or the entry may be a stray
      # file). It then contributes no samples; the labels of the other
      # classes are unaffected because idx still comes from `classes`.
      if not os.path.isdir(class_path):
        continue

      # os.listdir returns entries in arbitrary order; sort so that a given
      # index always refers to the same image.
      for img_name in sorted(os.listdir(class_path)):
        # Skip hidden files and anything that is not an image, which would
        # otherwise crash Image.open in the middle of an epoch.
        if img_name.startswith("."):
          continue
        if not img_name.lower().endswith(self.IMAGE_EXTENSIONS):
          continue

        self.image_paths.append(os.path.join(class_path, img_name))
        self.labels.append(idx)

  def __len__(self):
    """Return the number of indexed images."""
    return len(self.image_paths)

  def __getitem__(self, index):
    """Return ``(image, label)`` for the sample at ``index``."""

    #Step 1 -> load image
    img_path = self.image_paths[index]
    # The context manager closes the file promptly; convert() returns a fully
    # loaded RGB copy that stays valid after the file is closed.
    with Image.open(img_path) as raw_img:
      img = raw_img.convert("RGB")

    #Step 2 -> apply transform
    # Compare against None explicitly: a transform object may be falsy
    # (e.g. an empty container) and should still be applied.
    if self.transform is not None:
      img = self.transform(img)

    #Step 3 -> return (img, label)
    label = self.labels[index]
    return (img, label)

In [None]:
# Both splits use the class list read from the training folder, so a given
# label index refers to the same class in either dataset.
train_dataset = PlantVillageDataset(
  root_dir=train_dir,
  classes=classes,
  transform=train_transforms,
)
val_dataset = PlantVillageDataset(
  root_dir=val_dir,
  classes=classes,
  transform=val_transforms,
)

# Sanity check: how many images were indexed for each split.
print("Train samples:", len(train_dataset))
print("Val samples:", len(val_dataset))


In [None]:
# DataLoaders handle batching, shuffling, and efficient reading.
BATCH_SIZE = 32  # images per batch
SEED = 42        # fixes the shuffle order so a fresh kernel run is repeatable

# shuffle=True -> prevents models from memorizing order.
# The seeded generator makes the shuffled order repeat after
# "Restart Kernel -> Run All" (given the same dataset ordering), so the sample
# batch inspected below does not change from run to run.
train_loader = DataLoader(
  train_dataset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  generator=torch.Generator().manual_seed(SEED),
)

# Validation data is only evaluated, so its order is left untouched.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Pull one batch from the training loader to check shapes and labels.
first_batch = next(iter(train_loader))
images, labels = first_batch

print("Batch image tensor shape:", images.shape)
# Only the first ten labels are shown to keep the output short.
print("Batch labels:", labels[:10])


In [None]:
def show_batch(images, labels, classes, n=8):
  """Plot the first ``n`` images of a batch in one row, titled with class names.

  Args:
    images: Batch of images indexable along the first axis; each item is
      laid out as Channel,Height,Width and supports ``permute``.
    labels: Integer class indices, one per image.
    classes: Sequence of class names; ``classes[label]`` is the title.
    n: Maximum number of images to show. If the batch is smaller (e.g. the
      last batch of an epoch), only the available images are drawn.
  """
  # Never index past the end of the batch.
  n = min(n, len(images), len(labels))
  if n <= 0:
    return  # nothing to draw

  # squeeze=False keeps `axes` two-dimensional even when n == 1.
  fig, axes = plt.subplots(1, n, figsize=(18, 6), squeeze=False)

  for i, ax in enumerate(axes[0]):
    img = images[i].permute(1, 2, 0)  # Channel,Height,Width → HWC (for plotting)

    ax.imshow(img)
    ax.axis("off")

    # wrap long class names
    label = classes[int(labels[i])]
    ax.set_title(label, fontsize=8, wrap=True)

  fig.tight_layout()
  plt.show()


# Display the first 8 images of the sample batch with their class names.
show_batch(images=images, labels=labels, classes=classes, n=8)
