## Unified Class Mappings

In [1]:
import os
import shutil
from collections import defaultdict

root_dir = "datasets"  # replace with your root directory path

# ------------------------------------
# Class Mappings (normalize class names)
# ------------------------------------
# Each unified label is written once, next to every source folder name that
# is folded into it. The flat lookup table and the reverse table are both
# derived from this single grouping, so they cannot drift apart.
_aliases_by_unified_class = {
    "Early_blight": (
        "Tomato Early blight leaf",
        "Tomato_Early_blight",
    ),
    "Septoria_leaf_spot": (
        "Tomato Septoria leaf spot",
        "Tomato_Septoria_leaf_spot",
    ),
    "Bacterial_spot": (
        "Tomato leaf bacterial spot",
        "Tomato_Bacterial_spot",
        "Bacterial spot",
    ),
    "Late_blight": (
        "Tomato leaf late blight",
        "Tomato_Late_blight",
        "Late blight",
    ),
    "Leaf_Mold": (
        "Tomato mold leaf",
        "Tomato_Leaf_Mold",
        "Black mold",
    ),
    "Mosaic_virus": (
        "Tomato leaf mosaic virus",
        "Tomato__Tomato_mosaic_virus",
    ),
    "Yellow_Leaf_Curl_Virus": (
        "Tomato leaf yellow virus",
        "Tomato__Tomato_YellowLeaf__Curl_Virus",
    ),
    "Target_Spot": (
        "Tomato__Target_Spot",
    ),
    "Spider_mites": (
        "Tomato_Spider_mites_Two_spotted_spider_mite",
    ),
    "Healthy": (
        "Tomato_healthy",
        "health",
    ),
    "Gray_spot": (
        "Gray spot",
    ),
    "Powdery_mildew": (
        "powdery mildew",
    ),
}

# Forward mapping: original folder name → unified class
class_mapping = {
    alias: unified_name
    for unified_name, aliases in _aliases_by_unified_class.items()
    for alias in aliases
}

# Reverse mapping: unified class → original classes
unified_to_original = defaultdict(list)
for unified_name, aliases in _aliases_by_unified_class.items():
    unified_to_original[unified_name].extend(aliases)


## Combining Datasets

In [2]:
# Create combined dataset directory
combined_dir = os.path.join(root_dir, "combined_dataset")
os.makedirs(combined_dir, exist_ok=True)

# Folders under root_dir that this notebook generates. They are outputs, not
# source datasets, so they must not be traversed when the cell is re-run.
generated_dirs = {"combined_dataset", "balanced_dataset"}
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

# Destinations written during this run, used to surface name collisions
# (two source files resolving to the same output name) instead of letting
# one image silently replace another.
written_paths = set()
collisions = 0

# Traverse through datasets (sorted so the log order is deterministic)
for dataset_name in sorted(os.listdir(root_dir)):
    dataset_path = os.path.join(root_dir, dataset_name)
    if dataset_name in generated_dirs or not os.path.isdir(dataset_path):
        continue

    for orig_class in sorted(os.listdir(dataset_path)):
        orig_class_path = os.path.join(dataset_path, orig_class)
        if not os.path.isdir(orig_class_path):
            continue

        # Check if this class has a mapping
        if orig_class not in class_mapping:
            print(f"Skipping unmapped class: {dataset_name}/{orig_class}")
            continue

        unified_class = class_mapping[orig_class]
        class_out_dir = os.path.join(combined_dir, unified_class)
        os.makedirs(class_out_dir, exist_ok=True)

        for fname in os.listdir(orig_class_path):
            if not fname.lower().endswith(image_extensions):
                continue

            src_path = os.path.join(orig_class_path, fname)
            # ✅ New format: datasetname_unifiedclassname_fname
            # The dataset name is the leading token, so the source dataset
            # can be recovered later from the file name alone.
            new_fname = f"{dataset_name}_{unified_class}_{fname}"
            dst_path = os.path.join(class_out_dir, new_fname)

            if dst_path in written_paths:
                collisions += 1
                print(f"Name collision, overwriting: {dst_path} (from {src_path})")
            written_paths.add(dst_path)

            shutil.copy2(src_path, dst_path)

if collisions:
    print(f"Warning: {collisions} file(s) were overwritten because of name collisions.")

print("✅ Combined dataset created successfully!")


Skipping unmapped class: PlantDoc-Combined/Tomato leaf
✅ Combined dataset created successfully!


> Skipped `PlantDoc-Combined/Tomato leaf` because we don't know whether these leaves are healthy or diseased.

## Count of images in each class of Combined Dataset

In [3]:
import os
from collections import Counter

# Path to your merged dataset
combined_dir = os.path.join(root_dir, "combined_dataset")

# Count every extension the merge step copies into the combined dataset
# (including '.gif'), so the totals reflect all images that were merged.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

class_counts = Counter()

for class_name in os.listdir(combined_dir):
    class_dir = os.path.join(combined_dir, class_name)
    if not os.path.isdir(class_dir):
        continue

    class_counts[class_name] = sum(
        1 for f in os.listdir(class_dir)
        if f.lower().endswith(image_extensions)
    )

# Print class counts sorted by size (largest first)
for cls, count in class_counts.most_common():
    print(f"{cls}: {count}")


Yellow_Leaf_Curl_Virus: 3283
Bacterial_spot: 2344
Late_blight: 2118
Septoria_leaf_spot: 1919
Healthy: 1697
Spider_mites: 1676
Target_Spot: 1404
Leaf_Mold: 1110
Early_blight: 1083
Mosaic_virus: 427
Powdery_mildew: 157
Gray_spot: 84


> Removing `Spider_mites` because it is a pest infestation, not a disease.
> Removing `Powdery_mildew` and `Gray_spot` because each of them exists in only one dataset and has very few images.

In [4]:
import shutil
# Unified classes that are dropped from the combined dataset
remove_classes = ["Spider_mites", "Powdery_mildew", "Gray_spot"]

for cls in remove_classes:
    folder_path = os.path.join(combined_dir, cls)

    # Nothing to delete (e.g. the cell was already run once): report and move on.
    if not os.path.isdir(folder_path):
        print(f"Folder not found (skipped): {folder_path}")
        continue

    # Deletes the class folder and every image inside it.
    shutil.rmtree(folder_path)
    print(f"Deleted folder: {folder_path}")

Deleted folder: datasets\combined_dataset\Spider_mites
Deleted folder: datasets\combined_dataset\Powdery_mildew
Deleted folder: datasets\combined_dataset\Gray_spot


> Checking that the folders were successfully deleted.

In [5]:
import os
from collections import Counter

# Path to your merged dataset
combined_dir = os.path.join(root_dir, "combined_dataset")

# Count every extension the merge step copies into the combined dataset
# (including '.gif'), so the totals reflect all images that remain.
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

# Recount after the class folders were removed; a deleted class simply no
# longer shows up because its folder is gone.
class_counts = Counter()

for class_name in os.listdir(combined_dir):
    class_dir = os.path.join(combined_dir, class_name)
    if not os.path.isdir(class_dir):
        continue

    class_counts[class_name] = sum(
        1 for f in os.listdir(class_dir)
        if f.lower().endswith(image_extensions)
    )

# Print class counts sorted by size (largest first)
for cls, count in class_counts.most_common():
    print(f"{cls}: {count}")


Yellow_Leaf_Curl_Virus: 3283
Bacterial_spot: 2344
Late_blight: 2118
Septoria_leaf_spot: 1919
Healthy: 1697
Target_Spot: 1404
Leaf_Mold: 1110
Early_blight: 1083
Mosaic_virus: 427


## Class Balancing with same distributions of images from each dataset
> For classes with more than 2000 images, we downsample to 2000 images while preserving each source dataset's share of the class.
> For classes with 2000 or fewer images, we augment up to 2000 images; the augmented images follow the same per-dataset distribution as the original images.

In [7]:
import os
import random
import shutil
from collections import defaultdict, Counter

import torch
from PIL import Image
from torchvision import transforms

# ---------------------------
# Paths
# ---------------------------
# NOTE: this rebinds `root_dir`, which the earlier cells use for the parent
# "datasets" folder. Re-run the first cell before re-running any of them,
# otherwise they will resolve paths relative to the combined dataset.
root_dir = "datasets/combined_dataset"
output_dir = "datasets/balanced_dataset"
os.makedirs(output_dir, exist_ok=True)

# ---------------------------
# Params
# ---------------------------
TARGET_SIZE = 2000  # number of images every class should end up with
SEED = 42

# Two independent random sources are used by the balancing step, so both are
# seeded: Python's `random` picks which files are sampled, while torchvision's
# random transforms draw their parameters from torch's global generator.
random.seed(SEED)
torch.manual_seed(SEED)

# ---------------------------
# Augmentation pipeline (mild + realistic)
# ---------------------------
# Individual augmentation steps, applied to a PIL image in list order.
_augment_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.RandomVerticalFlip(p=0.1),  # less frequent than horizontal
    transforms.RandomPerspective(distortion_scale=0.1, p=0.3),  # mild warping
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    # Mild zoom; every augmented image comes out at 224x224.
    # NOTE(review): `scale` is documented as a fraction of the source image
    # area, so the 1.1 upper bound asks for crops larger than the image —
    # confirm whether (0.9, 1.0) was intended.
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.1)),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),  # small shifts
    # Round-trip through a tensor so the pipeline hands back a PIL image
    # that can be saved directly.
    transforms.ToTensor(),
    transforms.ToPILImage(),
]

augment = transforms.Compose(_augment_steps)

# ---------------------------
# Helper: extract dataset name
# ---------------------------
def get_dataset_name(filename):
    """Return the part of ``filename`` that precedes its first underscore.

    Files in the combined dataset are named
    ``<dataset>_<unified class>_<original name>``, so this prefix identifies
    the source dataset as long as the dataset folder name itself contains no
    underscore. A name without any underscore is returned unchanged.
    """
    prefix, _, _ = filename.partition("_")
    return prefix  # e.g., PlantVillage_earlyblight_xxx.jpg -> PlantVillage

# ---------------------------
# Helper: pretty print distributions
# ---------------------------
def print_distribution(title, dataset_groups, total):
    """Print a per-dataset count and percentage breakdown under a heading.

    Args:
        title: heading printed above the breakdown.
        dataset_groups: mapping of dataset name -> number of images.
        total: denominator for the percentages; when it is not positive,
            every percentage is reported as 0.
    """
    print(f"  {title}:")
    for dataset, n_images in dataset_groups.items():
        if total > 0:
            share = (n_images / total) * 100
        else:
            share = 0
        print(f"    {dataset}: {n_images} ({share:.2f}%)")

# ---------------------------
# Helper: exact proportional split
# ---------------------------
def _proportional_quotas(group_sizes, amount):
    """Split ``amount`` across groups in proportion to their sizes.

    Uses the largest-remainder method with integer arithmetic, so the quotas
    always sum to exactly ``amount``. Rounding each share independently can
    drift above or below the requested amount.

    Args:
        group_sizes: mapping of group name -> number of items in that group.
        amount: non-negative integer to distribute.

    Returns:
        dict mapping every group name to its integer quota.
    """
    total_size = sum(group_sizes.values())
    if total_size <= 0 or amount <= 0:
        return {name: 0 for name in group_sizes}

    quotas, remainders = {}, {}
    for name, size in group_sizes.items():
        quotas[name], remainders[name] = divmod(size * amount, total_size)

    # Flooring loses fewer units than there are groups; hand them to the
    # groups with the largest remainders (ties keep insertion order).
    leftover = amount - sum(quotas.values())
    for name in sorted(remainders, key=remainders.get, reverse=True)[:leftover]:
        quotas[name] += 1
    return quotas


# ---------------------------
# Main loop
# ---------------------------
# Every extension the merge step copies into the combined dataset, so no
# merged image is silently left out of the balanced dataset.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

for class_name in sorted(os.listdir(root_dir)):
    class_path = os.path.join(root_dir, class_name)
    if not os.path.isdir(class_path):
        continue

    print(f"\n📂 Processing class: {class_name}")
    out_class_path = os.path.join(output_dir, class_name)
    os.makedirs(out_class_path, exist_ok=True)

    # Sorted so the seeded sampling below does not depend on the order in
    # which the filesystem happens to list the directory.
    files = sorted(
        f for f in os.listdir(class_path)
        if f.lower().endswith(IMAGE_EXTENSIONS)
    )

    # Group by dataset
    dataset_groups = defaultdict(list)
    for f in files:
        dataset_groups[get_dataset_name(f)].append(f)

    total = len(files)
    before_counts = {ds: len(imgs) for ds, imgs in dataset_groups.items()}
    print_distribution("Before processing", before_counts, total)

    if total == 0:
        # Nothing to sample from or augment; the class cannot reach the target.
        print("  No images found; skipping class.")
        continue

    selected_files = []

    # ---------------------------
    # Case 1: more than TARGET_SIZE → Downsample
    # ---------------------------
    if total > TARGET_SIZE:
        # Quotas sum to exactly TARGET_SIZE and never exceed a group's size
        # (each share is strictly smaller than the group because
        # TARGET_SIZE < total), so no trimming or top-up pass is needed.
        quotas = _proportional_quotas(before_counts, TARGET_SIZE)
        for ds, imgs in dataset_groups.items():
            selected_files.extend(random.sample(imgs, quotas[ds]))

        # Save directly
        for f in selected_files:
            shutil.copy2(os.path.join(class_path, f), os.path.join(out_class_path, f))

    # ---------------------------
    # Case 2: at most TARGET_SIZE → Augment
    # ---------------------------
    else:
        selected_files = files[:]  # keep originals
        for f in files:
            shutil.copy2(os.path.join(class_path, f), os.path.join(out_class_path, f))

        needed = TARGET_SIZE - total
        print(f"  Augmenting {needed} images...")

        # Exact quotas: the number of augmented files written to disk always
        # equals `needed`, so the folder and the reported counts agree.
        quotas = _proportional_quotas(before_counts, needed)
        for ds, imgs in dataset_groups.items():
            for i in range(quotas[ds]):
                src_file = random.choice(imgs)
                # Context manager releases the file handle once the pixels
                # have been converted.
                with Image.open(os.path.join(class_path, src_file)) as src_img:
                    img = src_img.convert("RGB")
                aug_img = augment(img)
                # `i` is unique within a dataset group, so a source image
                # picked several times still gets distinct output names.
                aug_name = f"{os.path.splitext(src_file)[0]}_aug{i}.jpg"
                aug_img.save(os.path.join(out_class_path, aug_name))
                selected_files.append(aug_name)

    # ---------------------------
    # After distribution
    # ---------------------------
    after_counts = Counter(get_dataset_name(f) for f in selected_files)
    print_distribution("After processing", after_counts, len(selected_files))

print(f"\n✅ Balanced dataset (all classes = {TARGET_SIZE} images) created at:", output_dir)



📂 Processing class: Bacterial_spot
  Before processing:
    PlantDoc-Combined: 107 (4.56%)
    PlantVillage: 2127 (90.74%)
    taiwan: 110 (4.69%)
  After processing:
    PlantDoc-Combined: 91 (4.55%)
    PlantVillage: 1815 (90.75%)
    taiwan: 94 (4.70%)

📂 Processing class: Early_blight
  Before processing:
    PlantDoc-Combined: 83 (7.66%)
    PlantVillage: 1000 (92.34%)
  Augmenting 917 images...
  After processing:
    PlantDoc-Combined: 153 (7.65%)
    PlantVillage: 1847 (92.35%)

📂 Processing class: Healthy
  Before processing:
    PlantVillage: 1591 (93.75%)
    taiwan: 106 (6.25%)
  Augmenting 303 images...
  After processing:
    PlantVillage: 1875 (93.75%)
    taiwan: 125 (6.25%)

📂 Processing class: Late_blight
  Before processing:
    PlantDoc-Combined: 111 (5.24%)
    PlantVillage: 1909 (90.13%)
    taiwan: 98 (4.63%)
  After processing:
    PlantDoc-Combined: 105 (5.25%)
    PlantVillage: 1803 (90.15%)
    taiwan: 92 (4.60%)

📂 Processing class: Leaf_Mold
  Before proce