In [1]:
!pip install -q kaggle
!pip install opendatasets

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22


In [2]:
import os
import zipfile
import shutil
import random
import json
import numpy as np
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.utils import load_img, img_to_array, array_to_img, save_img
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
from google.colab import files
import opendatasets as od

In [3]:
files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"yonatanyi","key":"d12e204b815b2ff2d9d4045f0c121064"}'}

In [4]:
od.download(
    "https://www.kaggle.com/datasets/emmarex/plantdisease")

Dataset URL: https://www.kaggle.com/datasets/emmarex/plantdisease
Downloading plantdisease.zip to ./plantdisease


100%|██████████| 658M/658M [00:04<00:00, 158MB/s]







In [5]:
original_dir = '/content/plantdisease/plantvillage/PlantVillage'
split_dir = '/content/plantvillage_split'
split_ratio = 0.2
IMG_SIZE = (128, 128)

os.makedirs(split_dir, exist_ok=True)
for split in ['train', 'val']:
    os.makedirs(os.path.join(split_dir, split), exist_ok=True)

TOMATO_CLASSES = [
    'Tomato_Bacterial_spot',
    'Tomato_Early_blight',
    'Tomato_Late_blight',
    'Tomato_Septoria_leaf_spot',
    'Tomato_Spider_mites_Two_spotted_spider_mite',
    'Tomato__Target_Spot',
    'Tomato__Tomato_YellowLeaf__Curl_Virus',
    'Tomato__Tomato_mosaic_virus',
    'Tomato_healthy'
]

for class_name in os.listdir(original_dir):
    class_dir = os.path.join(original_dir, class_name)
    if not os.path.isdir(class_dir):
        continue
    images = [img for img in os.listdir(class_dir) if img.lower().endswith(('.jpg', '.jpeg', '.png'))]
    random.shuffle(images)
    split_idx = int(len(images) * (1 - split_ratio))
    train_imgs = images[:split_idx]
    val_imgs = images[split_idx:]
    if class_name in TOMATO_CLASSES:
      os.makedirs(os.path.join(split_dir, 'train', class_name), exist_ok=True)
      os.makedirs(os.path.join(split_dir, 'val', class_name), exist_ok=True)
      for img in train_imgs:
          shutil.copy(os.path.join(class_dir, img), os.path.join(split_dir, 'train', class_name, img))
      for img in val_imgs:
          shutil.copy(os.path.join(class_dir, img), os.path.join(split_dir, 'val', class_name, img))



def random_orient_augment(img_path, img_size):
    img = load_img(img_path, target_size=img_size)
    img_arr = img_to_array(img)
    if random.random() > 0.5:
        img_arr = tf.image.flip_left_right(img_arr)
    if random.random() > 0.5:
        img_arr = tf.image.flip_up_down(img_arr)
    k = random.randint(0, 3)
    img_arr = tf.image.rot90(img_arr, k=k)
    return array_to_img(img_arr)

for split in ['train', 'val']:
    src_dir = f'/content/plantvillage_split/{split}'
    dst_tomato = f'/content/data/{split}/'
    os.makedirs(dst_tomato, exist_ok=True)
    for class_name in os.listdir(src_dir):
        class_path = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_path):
            continue
        for img in os.listdir(class_path):
            if not img.lower().endswith(('.jpg', '.jpeg', '.png')):
                continue
            src_img = os.path.join(class_path, img)
            try:
                aug_img = random_orient_augment(src_img, IMG_SIZE)
                if class_name in TOMATO_CLASSES:
                    dst = dst_tomato
                    new_img_name = f"{class_name}_{img}"
                    dst_img = os.path.join(dst, new_img_name)
                    aug_img.save(dst_img)
            except Exception as e:
                continue
print('Train and Test Split is Done!')

Train and Test Split is Done!


In [6]:
model = models.Sequential([
    layers.Conv2D(32, (3,3), activation='relu', input_shape=IMG_SIZE + (3,)),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(64, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    layers.Conv2D(128, (3,3), activation='relu'),
    layers.MaxPooling2D(2,2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(len(TOMATO_CLASSES), activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [8]:
BATCH_SIZE = 32

train_dir = '/content/plantvillage_split/train'
val_dir = '/content/plantvillage_split/val'

train_ds = image_dataset_from_directory(
    train_dir,
    labels='inferred',
    label_mode='categorical',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    shuffle=True
)
val_ds = image_dataset_from_directory(
    val_dir,
    labels='inferred',
    label_mode='categorical',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    shuffle=False
)

Found 12043 files belonging to 9 classes.
Found 3016 files belonging to 9 classes.


In [9]:
EPOCHS = 20
history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)

Epoch 1/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 672ms/step - accuracy: 0.3468 - loss: 5.3057 - val_accuracy: 0.6058 - val_loss: 1.1233
Epoch 2/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m262s[0m 671ms/step - accuracy: 0.6336 - loss: 1.0696 - val_accuracy: 0.7745 - val_loss: 0.6646
Epoch 3/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m268s[0m 687ms/step - accuracy: 0.7345 - loss: 0.7674 - val_accuracy: 0.7881 - val_loss: 0.6525
Epoch 4/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m268s[0m 703ms/step - accuracy: 0.8049 - loss: 0.5623 - val_accuracy: 0.7905 - val_loss: 0.7178
Epoch 5/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m313s[0m 682ms/step - accuracy: 0.8373 - loss: 0.4716 - val_accuracy: 0.8452 - val_loss: 0.5209
Epoch 6/20
[1m377/377[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m263s[0m 697ms/step - accuracy: 0.8773 - loss: 0.3597 - val_accuracy: 0.8657 - val_loss: 0.4790
Epoc

In [10]:
model.save('/content/tomato_disease_classifier.h5')
with open('/content/tomato_classes.json', 'w') as f:
    json.dump(TOMATO_CLASSES, f)
print('Tomato Disease Classifer Saved!')



Tomato Disease Classifer Saved!
