# Train a TensorFlow Model for Plant Disease Classification

This notebook trains a convolutional neural network (CNN) to classify plant diseases using the preprocessed PlantVillage dataset.
We will:
1. Import necessary libraries and configure TensorFlow for GPU or CPU usage.
2. Load the training and validation datasets.
3. Train the model and save it.
4. Inspect the training history (loss and accuracy curves).

## Step 1: Imports & Setup

In [1]:
import os
import sys
import tensorflow as tf
import matplotlib.pyplot as plt

# Only let TensorFlow's Python logger emit ERROR-level records so the cell
# output stays readable.
tf.get_logger().setLevel('ERROR')

# Ask TensorFlow which physical GPUs it can see.
gpus = tf.config.list_physical_devices('GPU')

if gpus:
    try:
        # Enable memory growth on every GPU so TensorFlow allocates memory on
        # demand instead of reserving it all up front (helps avoid OOM errors).
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(f"GPU detected: {len(gpus)} physical, {len(logical_gpus)} logical.")
    except RuntimeError as e:
        # Best effort: report the problem and carry on with default GPU settings.
        print("Error while configuring GPU:", e)
else:
    # Hide CUDA devices from anything that reads this variable from here on.
    # NOTE(review): TensorFlow has already enumerated its devices above, so
    # this is likely a no-op for the current runtime — confirm whether it is
    # still needed.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print("No GPU detected. Using CPU mode.")

# Make the parent directory and its src/ folder importable.
# Paths are resolved to absolute form once, so later working-directory changes
# cannot break imports, and they are only added when missing, so re-running
# this cell does not pile up duplicate sys.path entries.
# Assumes the kernel's working directory is the folder containing this notebook.
_parent_dir = os.path.abspath('..')
_src_dir = os.path.join(_parent_dir, 'src')

if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)
if _src_dir not in sys.path:
    # Front of the path so the project's modules win over same-named packages.
    sys.path.insert(0, _src_dir)

# Project-local modules (resolved through the sys.path entries added above).
import config
from data_loader import get_train_val_ds
from train import train_model

# Show available devices
print("Devices available:", tf.config.list_physical_devices())

No GPU detected. Using CPU mode.
Devices available: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


## Step 2: Load Data

In [None]:
# Locations of the preprocessed PlantVillage splits, derived from the project config.
base_dir = config.DATA_PROCESSED_DIR / "PlantVillage"
train_dir = base_dir / "train"
val_dir = base_dir / "val"

print("Base:", base_dir, base_dir.exists())
print("Train:", train_dir.exists())
print("Val:", val_dir.exists())

# Fail early with a readable message: the class list below is built from this
# directory, so it must exist before anything else is attempted.
if not train_dir.is_dir():
    raise FileNotFoundError(f"Training directory not found: {train_dir}")

# The loader also reports class names; keep them under a separate name so they
# can be cross-checked instead of being silently overwritten.
train_ds, val_ds, loader_class_names = get_train_val_ds()

# Rebuild class names from the train directory (one sub-folder per class).
class_names = sorted(d.name for d in train_dir.iterdir() if d.is_dir())
num_classes = len(class_names)

if num_classes == 0:
    raise ValueError(f"No class sub-directories found in {train_dir}")

# A mismatch here would mean label indices and class names could be out of
# sync, so make it visible rather than ignoring it.
if loader_class_names is not None and list(loader_class_names) != class_names:
    print(
        "WARNING: class names reported by get_train_val_ds() differ from the "
        "sub-directories of the train folder:",
        list(loader_class_names),
    )

print("Classes:", class_names)
print("Number of classes:", num_classes)

## Step 3: Train Model

In [None]:
# Destination passed to train_model as save_path (presumably where the trained
# model is written — confirm in src/train.py).
model_save_path = config.MODELS_DIR / "plant_disease.keras"

# Run training; train_model hands back the model together with its history
# object, which the plotting cell below reads.
model, history = train_model(
    train_ds, val_ds,
    num_classes=num_classes,
    epochs=config.EPOCHS,
    save_path=model_save_path,
)

## Step 4: Inspect Training History

In [None]:
# One figure per metric, each comparing the training and validation curves.
# Assumes `history.history` maps metric names to per-epoch value lists (as a
# Keras History object does) — confirm against train_model's return value.
# Tuple layout: (train key, train label, val key, val label, y-axis label, title)
metric_plots = [
    ("loss", "train loss", "val_loss", "val loss",
     "Loss", "Training vs. validation loss"),
    ("accuracy", "train accuracy", "val_accuracy", "validation accuracy",
     "Accuracy", "Training vs. validation accuracy"),
]

for train_key, train_label, val_key, val_label, y_label, title in metric_plots:
    fig, ax = plt.subplots()
    ax.plot(history.history[train_key], label=train_label)
    ax.plot(history.history[val_key], label=val_label)
    # Titles and axis labels let each figure stand alone when skimming.
    ax.set(title=title, xlabel="Epoch", ylabel=y_label)
    ax.legend()
    plt.show()