In [None]:
! pip install --quiet datasets
! pip install --quiet tabulate
! pip install --quiet torchvision

In [None]:
import random
from pathlib import Path
from PIL import Image, UnidentifiedImageError
import shutil
from tqdm import tqdm
from torchvision import transforms

In [None]:
# Dataset locations, both relative to the current working directory:
# `train_dir` is the source that images are read from, `output_dir` is where
# the copied and augmented images are written.
train_dir = Path('dataset') / 'rice-leaf-disease-raw'
output_dir = Path('dataset') / 'rice-leaf-disease-augmented'

# Dataset Augmentation

In [None]:
# Create an augmentation pipeline using torchvision.
# Every transform below accepts a PIL image and returns a PIL image, so the
# pipeline's result can be saved directly with `.save()`. The former trailing
# ToTensor() -> ToPILImage() pair was dropped: it only converted the image to a
# float tensor and straight back, costing time without changing the output type.
augmentation_pipeline = transforms.Compose([
    # Geometric transformations
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(degrees=30),
    # degrees=0 adds no further rotation; this step only zooms (1x to 1.3x) and shears
    transforms.RandomAffine(degrees=0, scale=(1, 1.3), shear=30),

    # Photometric transformations
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.05),

    # Blur, applied with probability 0.5
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.5),
])

# Helper: decide from the file extension alone whether a path looks like an image
def is_image_file(file_path):
    """Return True when ``file_path`` ends in .jpg, .jpeg or .png (any letter case).

    Only the suffix is inspected; the file is not opened and need not exist.
    """
    extension = file_path.suffix.lower()
    return extension in ('.jpg', '.jpeg', '.png')

# Function to copy and rename original images
def copy_and_rename_images(src_dir, dest_dir):
    """Copy every readable image from ``src_dir`` into ``dest_dir`` under a new name.

    Each sub-directory of ``src_dir`` is treated as one category. A directory
    with the same name is created under ``dest_dir``, and every entry that
    passes ``is_image_file`` and Pillow's ``verify()`` check is copied there as
    ``<category>_<idx>.jpg``. ``idx`` is the 1-based position of the entry in
    the sorted directory listing; skipped entries still consume an index, so
    the numbering may contain gaps.

    Args:
        src_dir: ``pathlib.Path`` of the directory holding one folder per category.
        dest_dir: ``pathlib.Path`` of the directory to populate (created as needed).

    Raises:
        OSError: if ``src_dir`` cannot be listed, a destination folder cannot
            be created, or a copy fails. Unreadable image files do not raise;
            they are reported and skipped.
    """
    for category in tqdm(list(src_dir.iterdir()), desc="Copying and Renaming Original Images"):
        if not category.is_dir():
            continue

        dest_category_dir = dest_dir / category.name
        dest_category_dir.mkdir(parents=True, exist_ok=True)

        # Sorted so each file receives the same index on every run; glob() on
        # its own yields entries in arbitrary filesystem order.
        for idx, img_file in enumerate(sorted(category.glob('*')), 1):
            if not is_image_file(img_file):
                continue

            try:
                # The context manager releases the file handle that a bare
                # Image.open(...).verify() would leave open.
                with Image.open(img_file) as img:
                    img.verify()  # Ensure file is an image
            except (OSError, SyntaxError):
                # UnidentifiedImageError is an OSError subclass. verify() can
                # raise other exception types for damaged files (SyntaxError
                # is used by some Pillow format plugins), and a directory with
                # an image-like name raises OSError on open.
                print(f"Skipping non-image file: {img_file}")
                continue

            # The copy is kept outside the try block so that a genuine copy
            # failure is not misreported as a non-image file.
            # NOTE(review): the bytes are copied unchanged, so a PNG source
            # keeps PNG content even though the new name ends in ".jpg".
            new_filename = f"{category.name}_{idx}.jpg"
            shutil.copy(img_file, dest_category_dir / new_filename)

# Function to augment images and save them
def augment_images_in_dir(dir_path, target_count):
    """Top up every category folder in ``dir_path`` to ``target_count`` images.

    For each sub-directory, the images already present are used as sources. A
    source is picked at random, passed through ``augmentation_pipeline`` and
    saved next to the originals as ``<category>_aug_<n>.jpg`` until the folder
    holds ``target_count`` images. Folders that already hold at least
    ``target_count`` images are left untouched.

    Args:
        dir_path: ``pathlib.Path`` of the directory holding one folder per category.
        target_count: number of images each category folder should end up with.

    Raises:
        OSError: if ``dir_path`` cannot be listed or an augmented image cannot
            be written. Unreadable source images do not raise; they are
            reported and dropped from the candidate pool.
    """
    for category in tqdm(list(dir_path.iterdir()), desc="Augmenting Images"):
        if not category.is_dir():
            continue

        # Candidates are fixed before any output is written, so augmented files
        # produced by this call are never picked as sources. Sorting makes the
        # sequence of random picks depend only on the random state, not on
        # filesystem listing order.
        images = sorted(img for img in category.glob('*') if is_image_file(img))
        count = len(images)

        # Clamp at zero: a folder that already exceeds the target would
        # otherwise give the progress bar a negative total.
        remaining = max(target_count - count, 0)

        with tqdm(total=remaining, desc=f"Augmenting {category.name}", leave=False) as pbar:
            while count < target_count:
                if not images:
                    # Either the folder had no images to begin with or all of
                    # them turned out unreadable. Stop here rather than crash
                    # in random.choice([]) or loop forever.
                    print(f"No usable source images in {category}; stopped at {count} images")
                    break

                img_file = random.choice(images)
                try:
                    # convert() loads the pixel data and returns a new image,
                    # so the underlying file can be closed right away.
                    with Image.open(img_file) as source_img:
                        img = source_img.convert("RGB")
                except OSError:
                    # Covers UnidentifiedImageError (an OSError subclass) as
                    # well as truncated or otherwise unreadable files.
                    print(f"Skipping non-image file: {img_file}")
                    images.remove(img_file)  # never retry a file known to be bad
                    continue

                # Augmentation and saving stay outside the try block so that a
                # write failure surfaces instead of being treated as a bad source.
                augmented_img = augmentation_pipeline(img)
                new_filename = f"{category.name}_aug_{count + 1}.jpg"
                augmented_img.save(category / new_filename)
                count += 1
                pbar.update(1)

# Number of images every category folder should hold once augmentation is done
TARGET_IMAGES_PER_CATEGORY = 1000

# Check the source before deleting anything: without this guard a missing or
# mistyped source path would wipe the previous output and only then fail.
if not train_dir.is_dir():
    raise FileNotFoundError(f"Source dataset directory not found: {train_dir}")

# Remove existing output directory if present, so files from an earlier run
# are not mixed into the new dataset
if output_dir.exists():
    shutil.rmtree(output_dir)

# Copy original images
copy_and_rename_images(train_dir, output_dir)

# Augment images to reach a target count per category
augment_images_in_dir(output_dir, target_count=TARGET_IMAGES_PER_CATEGORY)

print("Dataset augmentation completed and saved in the output directory!")