In [2]:
import tensorflow as tf
import os
import numpy as np
import shutil
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

# --- STEP 1: SPLIT DATA INTO TRAIN, VAL, TEST ---
data_dir = "Chinese-Herbs-Dataset"
output_dir = "processed_data"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    os.makedirs(os.path.join(output_dir, "train"))
    os.makedirs(os.path.join(output_dir, "val"))
    os.makedirs(os.path.join(output_dir, "test"))

    class_names = os.listdir(data_dir)
    for class_name in class_names:
        class_path = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_path):
            continue

        images = os.listdir(class_path)
        train_imgs, temp_imgs = train_test_split(images, test_size=0.3, random_state=42)
        val_imgs, test_imgs = train_test_split(temp_imgs, test_size=0.5, random_state=42)

        for split, img_list in zip(["train", "val", "test"], [train_imgs, val_imgs, test_imgs]):
            split_dir = os.path.join(output_dir, split, class_name)
            os.makedirs(split_dir, exist_ok=True)
            for img in img_list:
                shutil.copy(os.path.join(class_path, img), os.path.join(split_dir, img))

# --- STEP 2: DATA AUGMENTATION ---
img_size = 224
batch_size = 32

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.3,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    fill_mode="nearest"
)

val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    os.path.join(output_dir, "train"),
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode="categorical"
)

val_generator = val_datagen.flow_from_directory(
    os.path.join(output_dir, "val"),
    target_size=(img_size, img_size),
    batch_size=batch_size,
    class_mode="categorical"
)

# --- STEP 3: BUILD MODEL ---
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(img_size, img_size, 3))
base_model.trainable = False  # Freeze initial layers

x = GlobalAveragePooling2D()(base_model.output)
x = BatchNormalization()(x)
x = Dense(512, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(256, activation="relu")(x)
x = Dropout(0.4)(x)
output_layer = Dense(train_generator.num_classes, activation="softmax")(x)

model = Model(inputs=base_model.input, outputs=output_layer)

# --- STEP 4: COMPILE & TRAIN ---
model.compile(optimizer=Adam(learning_rate=0.0005), loss="categorical_crossentropy", metrics=["accuracy"])

early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True)

history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    callbacks=[early_stopping],
    verbose=1
)

# --- STEP 5: FINE-TUNE ---
base_model.trainable = True  # Unfreeze for fine-tuning

model.compile(optimizer=Adam(learning_rate=0.0001), loss="categorical_crossentropy", metrics=["accuracy"])

history_fine = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=10,
    callbacks=[early_stopping],
    verbose=1
)

# --- STEP 6: SAVE MODEL ---
model.save("herb.h5")


Found 140 images belonging to 20 classes.
Found 20 images belonging to 20 classes.
Epoch 1/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 2s/step - accuracy: 0.0406 - loss: 3.4080 - val_accuracy: 0.0500 - val_loss: 3.0784
Epoch 2/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 3s/step - accuracy: 0.1033 - loss: 2.9757 - val_accuracy: 0.1000 - val_loss: 3.0203
Epoch 3/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 3s/step - accuracy: 0.1270 - loss: 2.8942 - val_accuracy: 0.1000 - val_loss: 2.9964
Epoch 4/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 3s/step - accuracy: 0.0916 - loss: 2.9648 - val_accuracy: 0.1000 - val_loss: 2.9871
Epoch 5/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 3s/step - accuracy: 0.1604 - loss: 2.7619 - val_accuracy: 0.1000 - val_loss: 2.9725
Epoch 6/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 3s/step - accuracy: 0.1727 - loss: 2.7572 - val_accur

