In [None]:
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
def create_model(num_classes=3, input_shape=(224, 224, 3)):
    """Build a transfer-learning classifier on a frozen ImageNet ResNet50 backbone.

    The ResNet50 convolutional base (``include_top=False``) is frozen, and a
    global-average-pool + softmax head is added on top. Only the head is
    trained. The model is compiled with Adam and categorical cross-entropy,
    so training labels are expected to be one-hot encoded.

    Args:
        num_classes: Number of output classes. Defaults to 3
            (bacterial_leaf_blight, brown_spot, leaf_smut).
        input_shape: (height, width, channels) of the input images. Channels
            must be 3 to use the ImageNet weights.

    Returns:
        A compiled ``tf.keras.Model``.

    Raises:
        ValueError: If ``num_classes`` is smaller than 2 (a softmax over a
            single unit always outputs 1.0 and cannot learn anything).
    """
    if num_classes < 2:
        raise ValueError(f"num_classes must be >= 2, got {num_classes}")

    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    # Freeze the whole backbone (recursively sets every inner layer to
    # non-trainable) so only the new classification head is updated.
    base_model.trainable = False

    pooled = GlobalAveragePooling2D()(base_model.output)
    predictions = Dense(num_classes, activation='softmax')(pooled)
    model = Model(inputs=base_model.input, outputs=predictions)

    # compile() must come after the freeze for it to take effect.
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

In [None]:
def train_model(train_dir="data/processed/train",
                val_dir="data/processed/validation",
                checkpoint_path="models/trained_model.h5",
                epochs=20,
                batch_size=32):
    """Train the classifier from ``create_model()`` on one-folder-per-class image data.

    Args:
        train_dir: Root directory of training images; each subdirectory is one
            class (``flow_from_directory`` convention).
        val_dir: Root directory of validation images with the same class
            subdirectories as ``train_dir``.
        checkpoint_path: File that ``ModelCheckpoint`` writes the best model to
            (highest ``val_accuracy``). Its parent directory is created if missing.
        epochs: Number of training epochs.
        batch_size: Images per batch for both generators.

    Returns:
        The model as it stands after the final epoch. NOTE: this is not
        necessarily the best epoch; the best weights are the ones saved at
        ``checkpoint_path``.

    Raises:
        ValueError: If the validation class folders differ from the training
            ones, or if the number of training classes does not match the
            model's output size.
    """
    import os

    model = create_model()
    # Resize images to whatever spatial size the model was built for, so the
    # generators cannot drift out of sync with create_model().
    target_size = tuple(model.input_shape[1:3])

    # The frozen ResNet50 ImageNet weights expect inputs processed by
    # resnet50.preprocess_input (RGB->BGR + per-channel ImageNet mean
    # subtraction, no scaling). Rescaling by 1/255 feeds the backbone
    # off-distribution inputs. Inference code must use the same preprocessing.
    preprocess = tf.keras.applications.resnet50.preprocess_input
    train_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    val_datagen = ImageDataGenerator(preprocessing_function=preprocess)
    train_generator = train_datagen.flow_from_directory(
        train_dir, target_size=target_size, batch_size=batch_size, class_mode='categorical')
    val_generator = val_datagen.flow_from_directory(
        val_dir, target_size=target_size, batch_size=batch_size, class_mode='categorical')

    # flow_from_directory derives label indices from the folder names it finds.
    # If the two splits contain different class sets, the same index would mean
    # different classes and validation metrics would be silently wrong.
    if val_generator.class_indices != train_generator.class_indices:
        raise ValueError(
            f"Validation classes {val_generator.class_indices} do not match "
            f"training classes {train_generator.class_indices}")
    n_outputs = model.output_shape[-1]
    if train_generator.num_classes != n_outputs:
        raise ValueError(
            f"Found {train_generator.num_classes} class folders in {train_dir!r}, "
            f"but the model predicts {n_outputs} classes")

    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, save_best_only=True, monitor='val_accuracy')

    model.fit(train_generator, epochs=epochs, validation_data=val_generator, callbacks=[checkpoint])
    return model

In [None]:
# Run the training pipeline: train_model() builds and fits the network, and
# its ModelCheckpoint callback writes the best (highest val_accuracy) epoch to disk.
model = train_model()
print("Training completed. Model saved to models/trained_model.h5")