02_data_pipeline.ipynb
- Dataset class
- DataLoaders
- Transforms
- Visual sanity checks


In [1]:
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torchvision.transforms as T

class PlantBinaryDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs listed in a CSV file.

    The CSV is read once, eagerly, with ``pandas.read_csv``. Each row is
    expected to provide an ``image_path`` column (a path ``PIL.Image.open``
    can read) and a ``label`` column convertible to ``int``.

    Args:
        csv_file: Path (or any source accepted by ``pandas.read_csv``) of the
            index CSV.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned (e.g. a ``torchvision.transforms.Compose``
            pipeline). ``None`` means the PIL image is returned as-is.
    """

    def __init__(self, csv_file, transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the index CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional row ``idx``.

        Returns:
            ``(image, label)``: ``image`` is the RGB PIL image, or whatever
            ``transform`` returns for it when a transform was given; ``label``
            is the row's ``label`` value cast to ``int``.
        """
        row = self.df.iloc[idx]  # positional lookup, independent of the index
        img_path = row["image_path"]
        label = int(row["label"])

        # Open inside a context manager so the file handle is released
        # deterministically: Pillow only auto-closes a path-opened file after
        # load() for single-frame images, so multi-frame files would otherwise
        # keep descriptors open in every DataLoader worker. convert() returns
        # a new, fully loaded image, so it stays valid after the file closes.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Explicit None check: truthiness would silently skip any transform
        # object that happens to evaluate as falsy (e.g. defines __len__).
        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Example transforms
def _make_transform(augment):
    """Build the image preprocessing pipeline for one data split.

    Every pipeline resizes to a fixed 224x224, converts to a tensor and
    normalises with the standard ImageNet per-channel mean/std. When
    ``augment`` is true, a random horizontal flip and ``RandomRotation(10)``
    are inserted between the resize and the tensor conversion.
    """
    steps = [T.Resize((224, 224))]
    if augment:
        steps.extend([T.RandomHorizontalFlip(), T.RandomRotation(10)])
    steps.append(T.ToTensor())
    steps.append(
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return T.Compose(steps)


# Training uses random augmentation; validation (and test) stay deterministic.
train_transform = _make_transform(augment=True)
val_transform = _make_transform(augment=False)

# Create datasets: one per split, each backed by its own index CSV. The test
# split reuses the validation pipeline, which has no random augmentation.
train_dataset = PlantBinaryDataset(
    csv_file="CSV/plantvillage_train.csv",
    transform=train_transform,
)
val_dataset = PlantBinaryDataset(
    csv_file="CSV/plantvillage_val.csv",
    transform=val_transform,
)
test_dataset = PlantBinaryDataset(
    csv_file="CSV/plantvillage_test.csv",
    transform=val_transform,
)

# Dataloaders: batches of 32 loaded by 2 worker processes; only the training
# split is shuffled so validation/test iteration order follows the CSV.
# NOTE(review): with num_workers > 0 on platforms that start workers via
# "spawn" (Windows, macOS), workers may fail to unpickle a Dataset class
# defined inside a notebook; if that happens, move PlantBinaryDataset into a
# .py module or set num_workers=0 there.
_loader_kwargs = dict(batch_size=32, num_workers=2)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
