In [14]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
print(f"TensorFlow Version: {tf.__version__}")

TensorFlow Version: 2.18.0


In [15]:
# rotation, zooming, shearing, and flipping as per base paper
data_augmentation = Sequential(
  [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomZoom(0.2),
    layers.RandomShear(x_factor=(0.0, 0.2))
  ],
  name="data_augmentation",
)

In [16]:
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 32
NUM_CLASSES = 38
DATA_DIR = '../input/plantvillage'

print(f"Image size set to: {IMAGE_SIZE}")
print(f"Data directory is: {DATA_DIR}")

Image size set to: (256, 256)
Data directory is: ../input/plantvillage


In [17]:
train_dir = '../input/plantvillage/PlantVillage/train' #train path
val_dir = '../input/plantvillage/PlantVillage/val'  #validation path

#train dataset
train_ds = tf.keras.utils.image_dataset_from_directory(
    train_dir,  
    seed=123,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE
)

#validation dataset
val_ds = tf.keras.utils.image_dataset_from_directory(
    val_dir,
    seed=123,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE
)

class_names = train_ds.class_names
print(f"Loaded {len(class_names)} classes.")
print(f"First 5 classes: {class_names[:5]}...")

Found 43444 files belonging to 38 classes.
Found 10861 files belonging to 38 classes.
Loaded 38 classes.
First 5 classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy']...


In [19]:
#Load ResNet101 V2 base model (pre-trained on ImageNet)
base_model_resnet = tf.keras.applications.ResNet101V2(
    input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
    include_top=False, 
    weights='imagenet'
)

#Freeze the base model
base_model_resnet.trainable = False

#Create your new model on top
model_resnet = Sequential([
    # Input layer
    layers.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    
    #rescaling pixels
    layers.Rescaling(1./255),
    
    # re-use the same data_augmentation layer
    data_augmentation,
    
    # The ResNet101 V2 base
    base_model_resnet,
    
    # The same classifier head
    layers.GlobalAveragePooling2D(),
    layers.Dense(NUM_CLASSES, activation='softmax') 
], name="ResNet101V2_Transfer_Learning")

#Compile the model
model_resnet.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model_resnet.summary()

In [20]:
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',  
    patience=2,              
    restore_best_weights=True 
)

In [22]:
epochs_to_run = 20 # No of epochs set to 20

history_resnet = model_resnet.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs_to_run,
    callbacks=[early_stopping_callback] 
)

print("ResNet101 V2 model training complete.")

Epoch 1/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 305ms/step - accuracy: 0.8159 - loss: 0.7019 - val_accuracy: 0.9188 - val_loss: 0.2616
Epoch 2/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m408s[0m 300ms/step - accuracy: 0.9215 - loss: 0.2600 - val_accuracy: 0.9311 - val_loss: 0.2132
Epoch 3/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m407s[0m 300ms/step - accuracy: 0.9362 - loss: 0.2024 - val_accuracy: 0.9344 - val_loss: 0.2046
Epoch 4/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m406s[0m 299ms/step - accuracy: 0.9444 - loss: 0.1759 - val_accuracy: 0.9400 - val_loss: 0.1764
Epoch 5/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m406s[0m 299ms/step - accuracy: 0.9467 - loss: 0.1640 - val_accuracy: 0.9482 - val_loss: 0.1602
Epoch 6/20
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m406s[0m 299ms/step - accuracy: 0.9502 - loss: 0.1524 - val_accuracy: 0.9407 - val_loss:

In [23]:
#Saving it without fine-tunning cause seems like a satisfying accuracy (Yeah 90+ isn't satisfying Gen_nil T-T) 
model_resnet.save("/kaggle/working/AI_resnet.h5")
print("Done saving..")

Done saving..
