In [5]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing import image_dataset_from_directory

# --- Data Augmentation ---
# Random perturbations applied to training batches only (via the train_ds
# .map below); the validation split is only divided by 255.
# NOTE(review): Keras' RandomContrast clips its output to the 0-255 pixel range
# (per the layer docs — confirm for the installed version). Applied after a
# [0, 1] rescale, that clip only catches negatives and lets values exceed 1.0,
# so training inputs would span a different range than the /255 validation
# inputs. Rescaling therefore runs last, after all random layers, which then
# see the raw 0-255 pixels produced by image_dataset_from_directory.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.2),
    layers.RandomContrast(0.2),
    layers.Rescaling(1./255),  # final step: map pixels to [0, 1]
])

# --- Load Dataset ---
# Both splits are read from class-named subdirectories of 'train' and 'valid';
# labels are one-hot vectors (label_mode='categorical').
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
SEED = 42  # fixes the shuffle order so re-runs are reproducible
AUTOTUNE = tf.data.AUTOTUNE


def load_split(directory, shuffle):
    """Load one image split from `directory` as batched (images, one-hot labels).

    Images are resized to IMAGE_SIZE, loaded as RGB, and batched by BATCH_SIZE;
    pixel values are left unscaled (rescaling happens in the .map calls below).
    `shuffle` toggles shuffling, seeded with SEED.
    """
    return image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='categorical',
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        color_mode='rgb',
        shuffle=shuffle,
        seed=SEED,
    )


train_raw = load_split('train', shuffle=True)
# Validation order does not affect the reported metrics, so skip shuffling.
val_raw = load_split('valid', shuffle=False)

# Capture class names before .map(): the mapped datasets no longer expose them.
class_names = train_raw.class_names
assert val_raw.class_names == class_names, (
    f"train/valid class folders differ: {class_names} vs {val_raw.class_names}"
)

# training=True makes the random augmentation explicit instead of relying on
# the layers' default; parallel map + prefetch overlap input prep with training.
train_ds = (
    train_raw
    .map(lambda x, y: (data_augmentation(x, training=True), y),
         num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)
# Validation images get only the 1/255 rescale, no augmentation.
val_ds = (
    val_raw
    .map(lambda x, y: (x / 255.0, y), num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# --- Get number of classes ---
# Labels are one-hot (label_mode='categorical'), so each label batch has static
# shape (batch, num_classes) and the class count is the LAST dimension.
# (len(...shape) would instead return the tensor rank — always 2 here —
# regardless of how many class folders exist.)
num_classes = train_ds.element_spec[1].shape[-1]
assert num_classes is not None, "label tensor has no static class dimension"

# --- Improved Model ---
# Small CNN classifier: three conv/max-pool stages, then a dense head with
# dropout and a softmax over `num_classes`.
model = models.Sequential([
    # Declare the input with an Input layer rather than passing `input_shape`
    # to Conv2D (deprecated for Sequential models in Keras 3). Must match the
    # image_size and RGB channels used when loading the data.
    layers.Input(shape=(128, 128, 3)),

    layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.5),  # regularize the dense head; the dataset is very small
    layers.Dense(num_classes, activation='softmax'),  # one probability per class
])

# categorical_crossentropy matches the one-hot labels (label_mode='categorical').
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

# --- Training ---
# Train for a fixed 20 epochs, evaluating on val_ds after each epoch; the
# per-epoch loss/accuracy values are recorded in `history.history`.
# NOTE(review): with only ~87 training images, consider an EarlyStopping
# callback (restore_best_weights=True) so the saved model is the best-validated one.
history = model.fit(train_ds, epochs=20, validation_data=val_ds)

# --- Save the model ---
# Write the trained model to a single `.keras` file in the current working directory.
model.save(filepath="false_smut_model.keras")


Found 87 files belonging to 2 classes.
Found 49 files belonging to 2 classes.


Epoch 1/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 11s 1s/step - accuracy: 0.4845 - loss: 0.7347 - val_accuracy: 0.6122 - val_loss: 0.6651
Epoch 2/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 932ms/step - accuracy: 0.8069 - loss: 0.4482 - val_accuracy: 0.6122 - val_loss: 0.6045
Epoch 3/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 861ms/step - accuracy: 0.8161 - loss: 0.3917 - val_accuracy: 0.7143 - val_loss: 0.5659
Epoch 4/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 866ms/step - accuracy: 0.8837 - loss: 0.3017 - val_accuracy: 0.7143 - val_loss: 0.5203
Epoch 5/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 5s 984ms/step - accuracy: 0.9343 - loss: 0.2617 - val_accuracy: 0.6122 - val_loss: 0.6520
Epoch 6/20
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 977ms/step - accuracy: 0.8743 - loss: 0.2625 - val_accuracy: 0.7551 - val_loss: 0.5632
Epoch 7/20
[1m3/3[0m [32m━━━━━━━━━━━━━━