# 🍚 Rice Grain Classification - 4-Class Model
**Upload your `rice` folder (one subfolder of images per rice class) to Google Drive first!**

In [None]:
# Step 1: Mount Google Drive
# Exposes "My Drive" under /content/drive so the BASE_PATH set below can resolve.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Step 2: Set your path (EDIT THIS!)
# Change this to where your rice folder is in Google Drive
BASE_PATH = '/content/drive/MyDrive/rice'  # <-- EDIT THIS PATH

import os

# Image formats accepted by tf.keras.utils.image_dataset_from_directory
# (matched case-insensitively). Counting only these keeps the numbers printed
# here in line with what Step 4 actually loads.
IMAGE_EXTENSIONS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')


def count_images(folder):
    """Return the number of image files anywhere under ``folder``.

    Subdirectories are walked recursively and extensions are compared
    case-insensitively against ``IMAGE_EXTENSIONS``, so stray entries such as
    ``.DS_Store``, ``desktop.ini`` or nested folders are not counted as images.
    """
    return sum(
        1
        for _root, _dirs, files in os.walk(folder)
        for name in files
        if name.lower().endswith(IMAGE_EXTENSIONS)
    )


# isdir (not exists): listing a plain file would raise NotADirectoryError.
if os.path.isdir(BASE_PATH):
    print(f"✅ Found: {BASE_PATH}")
    for folder in sorted(os.listdir(BASE_PATH)):
        class_dir = os.path.join(BASE_PATH, folder)
        if os.path.isdir(class_dir):  # each subfolder is one class
            print(f"   {folder}: {count_images(class_dir)} images")
else:
    print(f"❌ NOT FOUND: {BASE_PATH}")
    print("Upload your rice folder to Google Drive and update BASE_PATH!")

In [None]:
# Step 3: Setup
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

# Input resolution (height, width) and batch size for the data pipeline.
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16

# Report the runtime so a missing GPU is obvious before training starts.
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow: {tf.__version__}")
print(f"GPU: {gpu_devices}")

In [None]:
# Step 4: Create Datasets
# Both loaders use the same validation_split and seed, so Keras partitions the
# same shuffled file list: ~80% of files feed train_ds, and the held-out ~20%
# ("validation" subset) is used as the final test set.
train_ds = tf.keras.utils.image_dataset_from_directory(
    BASE_PATH, validation_split=0.2, subset="training", seed=42,
    image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')

test_ds = tf.keras.utils.image_dataset_from_directory(
    BASE_PATH, validation_split=0.2, subset="validation", seed=42,
    image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, label_mode='categorical')

# Class names come from the subfolder names; their order defines label indices.
CLASS_NAMES = train_ds.class_names
print(f"Classes: {CLASS_NAMES}")

# Split validation: carve ~20% of the training *batches* off as val_ds.
# NOTE(review): with shuffle=True (the default) the loader also reshuffles
# locally on each iteration, so take()/skip() may not see identical batch
# contents every time and a few images near the boundary could end up in both
# val_ds and train_ds — confirm; split at the file level if strict separation
# matters.
train_batches = len(train_ds)
if train_batches < 2:
    # With a single batch, skip() would leave nothing to train on.
    raise ValueError(
        f"Need at least 2 training batches to carve out a validation split, "
        f"got {train_batches}; add more images or lower BATCH_SIZE.")
val_batches = max(1, int(0.2 * train_batches))
val_ds = train_ds.take(val_batches)
train_ds = train_ds.skip(val_batches)

In [None]:
# Step 5: Preprocessing
def preprocess(img, lbl):
    """Cast the image batch to float32 and pass the labels through unchanged.

    No rescaling is applied here; pixel values keep their original range.
    """
    as_float = tf.cast(img, tf.float32)
    return as_float, lbl

def augment(img, lbl):
    """Randomly mirror images horizontally, then vertically; labels untouched."""
    for random_flip in (tf.image.random_flip_left_right, tf.image.random_flip_up_down):
        img = random_flip(img)
    return img, lbl

# Order matters: cache() must store only the deterministic preprocessing
# output. Random augmentation runs after cache()/shuffle() so every epoch gets
# freshly flipped images; caching after augment() would replay the first
# epoch's flips forever and effectively disable augmentation.
train_ds = (
    train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(1000)  # elements are already batches, so this reorders whole batches
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation splits are never augmented.
val_ds = val_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)

In [None]:
# Step 6: Build Model
# Transfer-learning classifier: ImageNet-pretrained EfficientNetB0 backbone
# (classification top removed) followed by a small dense head.
# Keras' EfficientNet models rescale inputs internally and expect 0-255
# pixels, which is why the preprocessing only casts to float32.
base = keras.applications.EfficientNetB0(
    weights='imagenet', include_top=False,
    input_shape=(*IMAGE_SIZE, 3))  # follow IMAGE_SIZE instead of hard-coding 224x224
base.trainable = False  # Phase 1: train the new head only

model = keras.Sequential([
    base,
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),
    layers.Dropout(0.2),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.3),
    layers.Dense(len(CLASS_NAMES), activation='softmax'),  # one unit per class folder
])

# categorical_crossentropy matches label_mode='categorical' (one-hot labels).
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
print(f"Model ready! Classes: {len(CLASS_NAMES)}")

In [None]:
# Step 7: Train Phase 1 (Frozen Base)
# Only the head is trainable here. EarlyStopping watches val_loss and, when it
# halts training, restores the weights from the best val_loss epoch;
# ModelCheckpoint (default monitor: val_loss) keeps the best model on Drive.
callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    keras.callbacks.ModelCheckpoint('/content/drive/MyDrive/rice_phase1.keras', save_best_only=True)
]

history1 = model.fit(train_ds, validation_data=val_ds, epochs=25, callbacks=callbacks)
# Highest val_accuracy seen in any epoch (not necessarily the restored epoch).
print(f"\n✅ Phase 1 Done! Best Val Acc: {max(history1.history['val_accuracy']):.4f}")

In [None]:
# Step 8: Train Phase 2 (Fine-tuning)
# Unfreeze only the top FINE_TUNE_LAST_N layers of the backbone.
FINE_TUNE_LAST_N = 60
base.trainable = True
for layer in base.layers[:-FINE_TUNE_LAST_N]:
    layer.trainable = False
# Keep BatchNormalization layers frozen: in Keras a BN layer with
# trainable=False runs in inference mode, so its pretrained moving statistics
# are not overwritten by the small fine-tuning batches.
for layer in base.layers[-FINE_TUNE_LAST_N:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile so the new trainable flags take effect; the low learning rate
# limits how far the pretrained weights move.
model.compile(optimizer=keras.optimizers.Adam(1e-4), loss='categorical_crossentropy', metrics=['accuracy'])

callbacks = [
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=12, restore_best_weights=True),
    keras.callbacks.ModelCheckpoint('/content/drive/MyDrive/rice_phase2.keras', save_best_only=True)
]

history2 = model.fit(train_ds, validation_data=val_ds, epochs=25, callbacks=callbacks)
# Highest val_accuracy seen in any phase-2 epoch.
print(f"\n✅ Phase 2 Done! Best Val Acc: {max(history2.history['val_accuracy']):.4f}")

In [None]:
# Step 9: Evaluate & Save
# evaluate() returns [loss, *compiled metrics]; 'accuracy' is the only metric,
# so it sits at index 1.
results = model.evaluate(test_ds)
print(f"\n📊 Test Accuracy: {results[1]:.4f}")

# Save final model
model.save('/content/drive/MyDrive/efficientnet_rice_final_inference.keras')
print("\n✅ Model saved to Google Drive!")

# Print a ready-to-paste LABEL_MAP (output index -> class name, in the
# CLASS_NAMES order the model was trained with) for the inference script.
print("\n" + "="*50)
print("LABEL_MAP for process_image_updated.py:")
print("="*50)
print("LABEL_MAP = {")
for i, name in enumerate(CLASS_NAMES):
    print(f'    {i}: "{name}",')
print("}")