In [46]:
# ! pip install -q --upgrade pillow
# ! pip install -q --upgrade tqdm
# ! pip install -q --upgrade torch
# ! pip install -q --upgrade torchvision
# ! pip install -q --upgrade "transformers[torch]"

In [47]:
import itertools
import math
import random
import shutil
from pathlib import Path

from PIL import Image, UnidentifiedImageError
from torchvision import transforms
from tqdm import tqdm, trange

In [48]:
# Dataset locations (relative to the notebook's working directory).
# `train_dir` is the folder handed to split_dataset() as its source;
# `output_dir` is where the train/test/validation tree is written.
train_dir = Path('dataset', 'rice_leaf_disease_raw')
output_dir = Path('dataset', 'rice_leaf_disease_split')

# Dataset Augmentation

In [49]:
def is_image_file(file_path):
    """Return True if Pillow can open and verify `file_path` as an image.

    Args:
        file_path: Path (or path string) of the candidate file.

    Returns:
        bool: True when `Image.open(...).verify()` succeeds, False when the
        file is unreadable, not an image, or fails verification.
    """
    try:
        # The context manager guarantees the underlying file handle is
        # released even when verify() raises part-way through.
        with Image.open(file_path) as candidate:
            candidate.verify()
        return True
    except (UnidentifiedImageError, OSError, SyntaxError):
        # SyntaxError: some Pillow format plugins (e.g. the PNG checksum
        # check) appear to report corrupt data this way — treat it as
        # "not a usable image" instead of aborting the whole scan.
        return False

In [50]:
def augment_image(image):
    """Run one randomised augmentation pass over a PIL image.

    The stages are applied in this order: random flips, a 500 px reflect
    pad, a random rotation of up to 30 degrees, a centre crop back to a
    square, colour/sharpness/contrast perturbations, and an optional
    Gaussian blur.

    Args:
        image: Input image; must expose a `.size` sequence
            (PIL images do) whose smallest entry sets the crop side.

    Returns:
        The augmented image produced by the composed torchvision pipeline.
    """
    # Final output is a square whose side is the shorter edge of the input.
    crop_side = min(image.size)

    geometric_stages = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # Pad before rotating — presumably so the rotated corners show
        # mirrored image content rather than empty fill.
        transforms.Pad(500, padding_mode="reflect"),
        transforms.RandomRotation(degrees=30),
        transforms.CenterCrop(crop_side),
    ]

    photometric_stages = [
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.05),
        transforms.RandomAdjustSharpness(sharpness_factor=2),
        transforms.RandomAutocontrast(p=0.5),
    ]

    blur_stages = [
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.5),
    ]

    pipeline = transforms.Compose(geometric_stages + photometric_stages + blur_stages)
    return pipeline(image)

In [51]:
def copy_data_directory(src_dir, dest_dir):
    """Copy every image of every category folder as a square RGB JPEG.

    Each sub-directory of `src_dir` is treated as a category. Its images are
    converted to RGB, centre-cropped to a square (side = shorter edge) and
    saved into `dest_dir / <Category Name In Title Case>` as
    `<category_name_lower_snake>_<idx>.jpg`.

    Args:
        src_dir (Path): Directory containing one sub-directory per category.
        dest_dir (Path): Directory under which the category folders are created.

    Returns:
        None. Files are written to disk.
    """
    description = "Copying and Renaming Original Images"
    directories = [d for d in src_dir.iterdir() if d.is_dir()]

    for category in tqdm(directories, desc=description):
        dest_category_dir = dest_dir / category.name.title()
        dest_category_dir.mkdir(parents=True, exist_ok=True)
        file_prefix = category.name.lower().replace(' ', '_')

        # sorted() makes the idx -> source-file mapping stable across runs;
        # bare glob order depends on the filesystem.
        # idx counts every directory entry, so skipped non-images leave gaps.
        for idx, image in enumerate(sorted(category.glob('*'))):
            if not is_image_file(image):
                continue
            # convert() returns an independent in-memory image, so the source
            # file can be closed right away.
            with Image.open(image) as source_img:
                img = source_img.convert("RGB")
            img = transforms.CenterCrop(min(img.size))(img)
            img.save(dest_category_dir / f"{file_prefix}_{idx}.jpg")

In [52]:
def split_dataset(src_dir, dest_dir, train_ratio=0.6, test_ratio=0.2, val_ratio=0.2, seed=None):
    """Split a per-category image folder into train / test / validation sets.

    Every sub-directory of `src_dir` is a category. Its valid images are
    shuffled, partitioned by the given ratios, converted to RGB,
    centre-cropped to a square (side = shorter edge) and written to
    `dest_dir/{train,test,validation}/<category name>/` as
    `<category_name_lower_snake>_<idx>.jpg`, where idx restarts at 0 in
    each split.

    Args:
        src_dir: Directory (Path or str) with one sub-directory per category.
        dest_dir: Directory (Path or str) that receives the split tree.
        train_ratio: Fraction of each category assigned to training.
        test_ratio: Fraction of each category assigned to testing.
        val_ratio: Fraction of each category assigned to validation
            (receives whatever remains after the two integer cut points).
        seed: Optional seed for a private RNG used for shuffling. With the
            default None, the module-level `random` state is used.

    Raises:
        AssertionError: If the three ratios do not sum to 1.
    """
    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 is 0.9999999999999999
    # in floating point and would fail an exact equality check.
    assert math.isclose(train_ratio + test_ratio + val_ratio, 1.0, abs_tol=1e-9), "Ratios must sum to 1."

    src_dir = Path(src_dir)
    dest_dir = Path(dest_dir)

    train_dir = dest_dir / "train"
    test_dir = dest_dir / "test"
    val_dir = dest_dir / "validation"

    # Create destination folders
    for directory in [train_dir, test_dir, val_dir]:
        directory.mkdir(parents=True, exist_ok=True)

    rng = random.Random(seed) if seed is not None else random

    for category in src_dir.iterdir():
        if not category.is_dir():
            continue

        # Sort first so that a fixed seed yields the same split regardless of
        # the order in which the filesystem lists the files.
        images = sorted(path for path in category.glob("*") if is_image_file(path))
        rng.shuffle(images)

        train_split = int(len(images) * train_ratio)
        test_split = int(len(images) * (train_ratio + test_ratio))

        train_images = images[:train_split]
        test_images = images[train_split:test_split]
        val_images = images[test_split:]

        file_prefix = category.name.lower().replace(' ', '_')

        # Copy images to respective folders
        for img_set, dest in zip([train_images, test_images, val_images], [train_dir, test_dir, val_dir]):
            category_dest = dest / category.name
            category_dest.mkdir(parents=True, exist_ok=True)

            for idx, img_path in enumerate(img_set):
                # convert() returns an independent in-memory image, so the
                # source file can be closed right away.
                with Image.open(img_path) as source_img:
                    img = source_img.convert("RGB")
                img = transforms.CenterCrop(min(img.size))(img)
                img.save(category_dest / f"{file_prefix}_{idx}.jpg")

In [53]:
def augment_images_in_dir(dir_path, target_count):
    """Top up every category folder to `target_count` images via augmentation.

    For each sub-directory of `dir_path`, the existing images are cycled
    through in turn; each pick is augmented with `augment_image` and saved
    next to its source as `<source stem>_aug_<k>.jpg`, where k counts the
    augmentations made from that particular source image.

    Args:
        dir_path: Directory (Path or str) containing one folder per category.
        target_count (int): Desired number of images per category. Categories
            already at or above this count get no new images.

    Returns:
        None. Augmented files are written into the category folders.
    """
    dir_path = Path(dir_path)

    # Only real directories are categories; stray files are ignored.
    categories = [entry for entry in dir_path.iterdir() if entry.is_dir()]

    for category in categories:
        images = [
            path for path in category.glob('*')
            if path.is_file() and is_image_file(path)
        ]
        if not images:
            # Nothing to cycle over: next() on an empty cycle would raise
            # StopIteration, so skip the category explicitly.
            print(f"Skipping {category.name}: no source images to augment")
            continue

        total = target_count - len(images)
        aug_count = dict.fromkeys(images, 0)
        source_cycle = itertools.cycle(images)

        for _ in trange(total, desc=f"Augmenting {category.name}"):
            img_file = next(source_cycle)
            # convert() returns an independent in-memory image, so the source
            # file can be closed right away.
            with Image.open(img_file) as source_img:
                img = source_img.convert("RGB")
            augmented_img = augment_image(img)

            aug_count[img_file] += 1
            new_filename = f"{img_file.stem}_aug_{aug_count[img_file]}.jpg"
            augmented_img.save(category / new_filename)

In [54]:
# Seed Python's RNG so the shuffle inside split_dataset() — and therefore
# which images land in train/test/validation — is repeatable on a fresh
# Restart & Run All (given the same directory listing).
# NOTE(review): torchvision's random transforms draw from torch's RNG, which
# is not seeded here, so the augmented images themselves still vary per run.
random.seed(42)

# Remove existing output directory if present so stale files from a previous
# run cannot mix into the new split
if output_dir.exists():
    shutil.rmtree(output_dir)

# Split the source images (train_dir) into train/test/validation under output_dir
split_dataset(train_dir, output_dir)

# Augment images to reach a target count per category (training split only)
augment_images_in_dir(Path(output_dir, "train"), target_count=1000)

print("Dataset augmentation completed and saved in the output directory!")

Augmenting Narrow Brown Leaf Spot: 100%|██████████| 934/934 [02:34<00:00,  6.05it/s]
Augmenting Rice Hispa: 100%|██████████| 878/878 [02:35<00:00,  5.66it/s]
Augmenting Sheath Blight: 100%|██████████| 838/838 [02:14<00:00,  6.25it/s]
Augmenting Leaf Blast: 100%|██████████| 822/822 [02:18<00:00,  5.96it/s]
Augmenting Bacterial Leaf Blight: 100%|██████████| 892/892 [02:25<00:00,  6.14it/s]
Augmenting Healthy Rice Leaf: 100%|██████████| 906/906 [04:44<00:00,  3.19it/s]
Augmenting Brown Spot: 100%|██████████| 844/844 [02:17<00:00,  6.15it/s]
Augmenting Leaf scald: 100%|██████████| 893/893 [02:26<00:00,  6.09it/s]

Dataset augmentation completed and saved in the output directory!



