In [1]:
!pip install kaggle



In [2]:
from google.colab import files

# Prompt for kaggle.json. files.upload() returns a mapping of filename to the
# raw file bytes; binding it to a name stops the notebook from echoing that
# mapping (and therefore the API key) into the saved cell output.
_uploaded = files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"belalsaeid","key":"<REDACTED>"}'}

In [3]:
# Copy the uploaded API token into ~/.kaggle (presumably where the Kaggle CLI
# looks for it — confirm against the CLI docs).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json

In [4]:
!kaggle datasets download -d abdallahalidev/plantvillage-dataset
from zipfile import ZipFile

# Unzip the downloaded dataset
with ZipFile("plantvillage-dataset.zip", 'r') as zip_ref:
    zip_ref.extractall()

Dataset URL: https://www.kaggle.com/datasets/abdallahalidev/plantvillage-dataset
License(s): CC-BY-NC-SA-4.0
Downloading plantvillage-dataset.zip to /content
 99% 2.02G/2.04G [00:14<00:00, 43.3MB/s]
100% 2.04G/2.04G [00:14<00:00, 146MB/s] 


In [None]:
!pip install fastapi uvicorn pyngrok nest_asyncio tensorflow pillow numpy python-multipart

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
import uvicorn
from pyngrok import ngrok
import os
import random
import shutil
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Seed every random number generator used below (Python's `random`, NumPy and
# TensorFlow) with the same value.
seed = 42
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here probably only affects child processes — confirm.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Define constants
img_size = 256  # side length in pixels; images are resized to img_size x img_size
batch_size = 32  # batch size handed to the data generators
# Saved model that is loaded when it exists and written when training is
# interrupted. Assumes Google Drive is already mounted at /content/drive —
# no mount call is visible in this notebook; TODO confirm.
MODEL_PATH = "/content/drive/MyDrive/Colab Notebooks/model_epoch_20.h5"

# Create and compile the model
def create_model(num_classes):
    """Build and compile the CNN image classifier.

    The network is two Conv2D + MaxPooling2D stages followed by a flattened
    256-unit dense layer, 50% dropout and a softmax output layer. It takes
    ``img_size`` x ``img_size`` inputs with 3 channels.

    Args:
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled Sequential model (Adam with learning rate 1e-4,
        categorical cross-entropy loss, accuracy metric).
    """
    from tensorflow.keras.optimizers import Adam

    conv_stack = [
        layers.Conv2D(32, (3,3), activation='relu', input_shape=[img_size, img_size, 3]),
        layers.MaxPooling2D(2,2),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D(2,2),
    ]
    classifier_head = [
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),  # dropout to reduce overfitting
        layers.Dense(num_classes, activation='softmax'),
    ]
    network = models.Sequential(conv_stack + classifier_head)

    # A low learning rate is used for more stable training.
    network.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(learning_rate=0.0001),
        metrics=['accuracy'],
    )
    return network


# Image preprocessing function
def load_and_preprocess(image_path):
    """Load one image and turn it into a normalised single-item batch.

    The image is converted to RGB, resized to ``img_size`` x ``img_size`` and
    divided by 255.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        NumPy array of shape ``(1, img_size, img_size, 3)``.
    """
    # Image.open holds the underlying file open; the context manager closes it
    # as soon as the pixel data has been copied into the converted image, so
    # repeated API calls do not accumulate open file handles.
    with Image.open(image_path) as source_img:
        img = source_img.convert('RGB').resize([img_size, img_size])
    img_array = np.array(img)
    # Add the leading batch axis and apply the same 1/255 scaling the training
    # generators use.
    img_array = np.expand_dims(img_array, axis=0) / 255
    return img_array

def predict_image_class(model, image_path, class_indices):
    """Return the label of the highest-scoring class for one image.

    Args:
        model: Object whose ``predict`` is called on the preprocessed batch.
        image_path: Image location, forwarded to ``load_and_preprocess``.
        class_indices: Mapping from class index to class label.

    Returns:
        The ``class_indices`` entry for the index with the largest score.
    """
    batch = load_and_preprocess(image_path)
    # predict() returns one row per batch item; the batch holds one image.
    scores = model.predict(batch)[0]
    return class_indices[np.argmax(scores)]


# Setup data generators
def setup_data_generators(dataset_path="/content/plantvillage dataset/color"):
    """Create the training and validation directory iterators.

    Training images are rescaled to [0, 1] and randomly augmented. Validation
    images are only rescaled, so validation metrics (and the early stopping
    that monitors them) are measured on undistorted images.

    Args:
        dataset_path: Directory with one sub-directory per class.

    Returns:
        Tuple ``(train_gen, validation_gen)`` of iterators yielding batches of
        ``batch_size`` images sized ``img_size`` x ``img_size`` with
        categorical (one-hot) labels.
    """
    validation_fraction = 0.2

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_fraction,
        rotation_range=20,            # Randomly rotate images by up to 20 degrees
        width_shift_range=0.1,        # Randomly shift images horizontally by 10%
        height_shift_range=0.1,       # Randomly shift images vertically by 10%
        shear_range=0.1,              # Apply shearing transformations
        zoom_range=0.1,               # Randomly zoom in on images
        horizontal_flip=True          # Randomly flip images horizontally
    )
    # No augmentation for validation. The same validation_split is passed so
    # both generators carve the directory with the same fraction.
    validation_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_fraction,
    )

    flow_options = dict(
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='categorical',
    )
    train_gen = train_datagen.flow_from_directory(
        dataset_path, subset='training', **flow_options
    )
    validation_gen = validation_datagen.flow_from_directory(
        dataset_path, subset='validation', **flow_options
    )
    return train_gen, validation_gen


# Create FastAPI app with CORS enabled
app = FastAPI()
# Credentialed cross-origin requests are not allowed together with wildcard
# origins, and this API reads nothing but the uploaded file (no cookies or
# auth headers), so credentials are switched off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins (replace "*" with your web app's domain for production)
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set this flag to True if you want to train the model for additional epochs in this session.
# Set it to False if you only want to load the saved model for the API.
should_train = False  # Change to True to train for additional epochs

if should_train:
    # Setup data generators
    train_gen, validation_gen = setup_data_generators()

    # If a saved model exists, load it. Otherwise, create a new one.
    if os.path.exists(MODEL_PATH):
        print("Loading saved model for training...")
        model = tf.keras.models.load_model(MODEL_PATH)
    else:
        print("Training model from scratch...")
        model = create_model(train_gen.num_classes)

    # Define number of epochs
    epochs = 100

    # Setup callbacks: write a full model file after every epoch, and stop once
    # val_loss has not improved for 3 epochs (restoring the best weights).
    checkpoint_cb = ModelCheckpoint(
        "/content/drive/MyDrive/Colab Notebooks/model_epoch_{epoch:02d}.h5",
        save_best_only=False,
        save_weights_only=False,
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    try:
        # Train the model
        model.fit(
            train_gen,
            steps_per_epoch=train_gen.samples // batch_size,
            epochs=epochs,
            validation_data=validation_gen,
            validation_steps=validation_gen.samples // batch_size,
            callbacks=[checkpoint_cb, early_stop]
        )
    except KeyboardInterrupt:
        print("Training interrupted manually. Saving model...")
    else:
        print("Training finished. Saving model...")
    # Save after a normal finish as well as after a manual interrupt, so the
    # model that will be served always exists at MODEL_PATH. Other exceptions
    # propagate without saving.
    model.save(MODEL_PATH)
    print("Model saved.")
else:
    # Serving an untrained network would return meaningless predictions, so a
    # missing model file is an error rather than a silent fallback.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"No saved model at {MODEL_PATH!r}. Set should_train = True to train one first."
        )

    # The generators are created here only to obtain the class-name mapping.
    train_gen, _ = setup_data_generators()

    print("Loading saved model for API usage...")
    model = tf.keras.models.load_model(MODEL_PATH)


# Create class indices for mapping predictions (class index -> directory name)
class_indices = {v: k for k, v in train_gen.class_indices.items()}

# Define the prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    The upload is written to a temporary file under ``temp_uploads``,
    classified with the module-level ``model``, and the temporary file is
    removed again.

    Args:
        file: The uploaded image.

    Returns:
        ``{"name": ..., "condition": ...}`` built from the predicted class
        label (the text before and after ``___``, with underscores replaced by
        spaces), or a 500 JSON response with a ``message`` on any error.
    """
    import tempfile  # needed only by this handler

    temp_path = None
    try:
        temp_dir = "temp_uploads"
        os.makedirs(temp_dir, exist_ok=True)
        # The client-supplied file.filename is deliberately not used to build
        # the path: it is untrusted and could contain "../" components.
        with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as buffer:
            temp_path = buffer.name
            shutil.copyfileobj(file.file, buffer)
        predict_name = predict_image_class(model, temp_path, class_indices)
        # A label without "___" yields an empty condition instead of an error.
        plant, _, condition = predict_name.partition("___")
        info = {
            "name": plant.replace("_", " "),
            "condition": condition.replace("_", " ")
        }
        return info
    except Exception as e:
        return JSONResponse(status_code=500, content={"message": f"Error processing image: {str(e)}"})
    finally:
        # Always remove the temporary upload so files do not pile up on disk.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)


# Start the server with ngrok
import getpass

# Allow uvicorn to run inside the notebook's event loop.
nest_asyncio.apply()
# The ngrok auth token is a secret: it is read from the NGROK_AUTHTOKEN
# environment variable, or prompted for, and never written into the notebook.
# Any token that was ever committed in a notebook should be revoked.
ngrok_auth_token = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)
public_url = ngrok.connect(8000)
print(f"API is running at: {public_url}")
uvicorn.run(app, host="0.0.0.0", port=8000)

ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-23' coro=<Server.serve() done, defined at /usr/local/lib/python3.11/dist-packages/uvicorn/server.py:68> exception=KeyboardInterrupt()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/main.py", line 580, in run
    server.run()
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/server.py", line 66, in run
    return asyncio.run(self.serve(sockets=sockets))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 30, in run
    return loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 92, in run_until_complete
    self._run_once()
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 133, in _run_once
    handle._run()
  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run
    s

Found 43456 images belonging to 38 classes.
Found 10849 images belonging to 38 classes.
Training model from scratch...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


Epoch 1/100
[1m 460/1358[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m1:14[0m 83ms/step - accuracy: 0.2659 - loss: 2.8973

ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-1' coro=<Server.serve() done, defined at /usr/local/lib/python3.11/dist-packages/uvicorn/server.py:68> exception=KeyboardInterrupt()>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/main.py", line 580, in run
    server.run()
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/server.py", line 66, in run
    return asyncio.run(self.serve(sockets=sockets))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 30, in run
    return loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 92, in run_until_complete
    self._run_once()
  File "/usr/local/lib/python3.11/dist-packages/nest_asyncio.py", line 133, in _run_once
    handle._run()
  File "/usr/lib/python3.11/asyncio/events.py", line 84, in _run
    se

[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step - accuracy: 0.3811 - loss: 2.3865



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 99ms/step - accuracy: 0.3812 - loss: 2.3861 - val_accuracy: 0.7453 - val_loss: 0.9434
Epoch 2/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.6590 - loss: 1.2096



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 99ms/step - accuracy: 0.6590 - loss: 1.2095 - val_accuracy: 0.8252 - val_loss: 0.6371
Epoch 3/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.7279 - loss: 0.9265



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m116s[0m 85ms/step - accuracy: 0.7279 - loss: 0.9264 - val_accuracy: 0.8414 - val_loss: 0.5466
Epoch 4/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.7722 - loss: 0.7547



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 104ms/step - accuracy: 0.7722 - loss: 0.7547 - val_accuracy: 0.8617 - val_loss: 0.4626
Epoch 5/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.7967 - loss: 0.6529



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 106ms/step - accuracy: 0.7967 - loss: 0.6529 - val_accuracy: 0.8837 - val_loss: 0.3889
Epoch 6/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 64ms/step - accuracy: 0.8313 - loss: 0.5533



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m173s[0m 85ms/step - accuracy: 0.8313 - loss: 0.5533 - val_accuracy: 0.8844 - val_loss: 0.3708
Epoch 7/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.8433 - loss: 0.4988



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 83ms/step - accuracy: 0.8433 - loss: 0.4988 - val_accuracy: 0.8972 - val_loss: 0.3324
Epoch 8/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step - accuracy: 0.8625 - loss: 0.4431



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m141s[0m 83ms/step - accuracy: 0.8625 - loss: 0.4431 - val_accuracy: 0.8979 - val_loss: 0.3323
Epoch 9/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step - accuracy: 0.8758 - loss: 0.3894



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m127s[0m 94ms/step - accuracy: 0.8758 - loss: 0.3894 - val_accuracy: 0.9068 - val_loss: 0.2957
Epoch 10/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.8900 - loss: 0.3484



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m119s[0m 88ms/step - accuracy: 0.8900 - loss: 0.3484 - val_accuracy: 0.9081 - val_loss: 0.2955
Epoch 11/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.8932 - loss: 0.3301



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 89ms/step - accuracy: 0.8932 - loss: 0.3301 - val_accuracy: 0.9051 - val_loss: 0.2944
Epoch 12/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step - accuracy: 0.9062 - loss: 0.2882



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m125s[0m 92ms/step - accuracy: 0.9062 - loss: 0.2882 - val_accuracy: 0.9151 - val_loss: 0.2656
Epoch 13/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step - accuracy: 0.9164 - loss: 0.2569



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m147s[0m 96ms/step - accuracy: 0.9164 - loss: 0.2569 - val_accuracy: 0.9142 - val_loss: 0.2691
Epoch 14/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.9251 - loss: 0.2379



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m142s[0m 95ms/step - accuracy: 0.9251 - loss: 0.2379 - val_accuracy: 0.9174 - val_loss: 0.2631
Epoch 15/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.9309 - loss: 0.2122



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m135s[0m 99ms/step - accuracy: 0.9309 - loss: 0.2122 - val_accuracy: 0.9126 - val_loss: 0.2745
Epoch 16/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step - accuracy: 0.9345 - loss: 0.1999



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 87ms/step - accuracy: 0.9345 - loss: 0.1999 - val_accuracy: 0.9147 - val_loss: 0.2682
Epoch 17/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.9415 - loss: 0.1840



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m137s[0m 84ms/step - accuracy: 0.9415 - loss: 0.1840 - val_accuracy: 0.9194 - val_loss: 0.2555
Epoch 18/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step - accuracy: 0.9439 - loss: 0.1683



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 95ms/step - accuracy: 0.9439 - loss: 0.1683 - val_accuracy: 0.9113 - val_loss: 0.2883
Epoch 19/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step - accuracy: 0.9466 - loss: 0.1586



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m135s[0m 90ms/step - accuracy: 0.9466 - loss: 0.1586 - val_accuracy: 0.9192 - val_loss: 0.2629
Epoch 20/100
[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step - accuracy: 0.9518 - loss: 0.1486



[1m1358/1358[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 97ms/step - accuracy: 0.9518 - loss: 0.1486 - val_accuracy: 0.9213 - val_loss: 0.2591
API is running at: NgrokTunnel: "https://8964-34-125-217-213.ngrok-free.app" -> "http://localhost:8000"


INFO:     Started server process [550]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 49ms/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step
INFO:     197.51.155.117:0 - "POST /predict/ HTTP/1.1" 200 OK


In [6]:
!pip install fastapi uvicorn pyngrok nest_asyncio tensorflow pillow numpy python-multipart

from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import nest_asyncio
import uvicorn
from pyngrok import ngrok
import os
import random
import shutil
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Seed every random number generator used below (Python's `random`, NumPy and
# TensorFlow) with the same value.
seed = 42
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here probably only affects child processes — confirm.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Define constants
img_size = 256  # side length in pixels; images are resized to img_size x img_size
batch_size = 32  # batch size handed to the data generators
# Saved model that is loaded when it exists and written when training is
# interrupted. Assumes Google Drive is already mounted at /content/drive —
# no mount call is visible in this notebook; TODO confirm.
MODEL_PATH = "/content/drive/MyDrive/Colab Notebooks/model_epoch_20.h5"

# Create and compile the model
def create_model(num_classes):
    """Build and compile the CNN image classifier.

    The network is two Conv2D + MaxPooling2D stages followed by a flattened
    256-unit dense layer, 50% dropout and a softmax output layer. It takes
    ``img_size`` x ``img_size`` inputs with 3 channels.

    Args:
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled Sequential model (Adam with learning rate 1e-4,
        categorical cross-entropy loss, accuracy metric).
    """
    from tensorflow.keras.optimizers import Adam

    conv_stack = [
        layers.Conv2D(32, (3,3), activation='relu', input_shape=[img_size, img_size, 3]),
        layers.MaxPooling2D(2,2),
        layers.Conv2D(64, (3,3), activation='relu'),
        layers.MaxPooling2D(2,2),
    ]
    classifier_head = [
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),  # dropout to reduce overfitting
        layers.Dense(num_classes, activation='softmax'),
    ]
    network = models.Sequential(conv_stack + classifier_head)

    # A low learning rate is used for more stable training.
    network.compile(
        loss='categorical_crossentropy',
        optimizer=Adam(learning_rate=0.0001),
        metrics=['accuracy'],
    )
    return network


# Image preprocessing function
def load_and_preprocess(image_path):
    """Load one image and turn it into a normalised single-item batch.

    The image is converted to RGB, resized to ``img_size`` x ``img_size`` and
    divided by 255.

    Args:
        image_path: Location of the image, passed straight to ``Image.open``.

    Returns:
        NumPy array of shape ``(1, img_size, img_size, 3)``.
    """
    # Image.open holds the underlying file open; the context manager closes it
    # as soon as the pixel data has been copied into the converted image, so
    # repeated API calls do not accumulate open file handles.
    with Image.open(image_path) as source_img:
        img = source_img.convert('RGB').resize([img_size, img_size])
    img_array = np.array(img)
    # Add the leading batch axis and apply the same 1/255 scaling the training
    # generators use.
    img_array = np.expand_dims(img_array, axis=0) / 255
    return img_array

def predict_image_class(model, image_path, class_indices):
    """Return the label of the highest-scoring class for one image.

    Args:
        model: Object whose ``predict`` is called on the preprocessed batch.
        image_path: Image location, forwarded to ``load_and_preprocess``.
        class_indices: Mapping from class index to class label.

    Returns:
        The ``class_indices`` entry for the index with the largest score.
    """
    batch = load_and_preprocess(image_path)
    # predict() returns one row per batch item; the batch holds one image.
    scores = model.predict(batch)[0]
    return class_indices[np.argmax(scores)]


# Setup data generators
def setup_data_generators(dataset_path="/content/plantvillage dataset/color"):
    """Create the training and validation directory iterators.

    Training images are rescaled to [0, 1] and randomly augmented. Validation
    images are only rescaled, so validation metrics (and the early stopping
    that monitors them) are measured on undistorted images.

    Args:
        dataset_path: Directory with one sub-directory per class.

    Returns:
        Tuple ``(train_gen, validation_gen)`` of iterators yielding batches of
        ``batch_size`` images sized ``img_size`` x ``img_size`` with
        categorical (one-hot) labels.
    """
    validation_fraction = 0.2

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_fraction,
        rotation_range=20,            # Randomly rotate images by up to 20 degrees
        width_shift_range=0.1,        # Randomly shift images horizontally by 10%
        height_shift_range=0.1,       # Randomly shift images vertically by 10%
        shear_range=0.1,              # Apply shearing transformations
        zoom_range=0.1,               # Randomly zoom in on images
        horizontal_flip=True          # Randomly flip images horizontally
    )
    # No augmentation for validation. The same validation_split is passed so
    # both generators carve the directory with the same fraction.
    validation_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=validation_fraction,
    )

    flow_options = dict(
        target_size=(img_size, img_size),
        batch_size=batch_size,
        class_mode='categorical',
    )
    train_gen = train_datagen.flow_from_directory(
        dataset_path, subset='training', **flow_options
    )
    validation_gen = validation_datagen.flow_from_directory(
        dataset_path, subset='validation', **flow_options
    )
    return train_gen, validation_gen


# Create FastAPI app with CORS enabled
app = FastAPI()
# Credentialed cross-origin requests are not allowed together with wildcard
# origins, and this API reads nothing but the uploaded file (no cookies or
# auth headers), so credentials are switched off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins (replace "*" with your web app's domain for production)
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Set this flag to True if you want to train the model for additional epochs in this session.
# Set it to False if you only want to load the saved model for the API.
should_train = False  # Change to True to train for additional epochs

if should_train:
    # Setup data generators
    train_gen, validation_gen = setup_data_generators()

    # If a saved model exists, load it. Otherwise, create a new one.
    if os.path.exists(MODEL_PATH):
        print("Loading saved model for training...")
        model = tf.keras.models.load_model(MODEL_PATH)
    else:
        print("Training model from scratch...")
        model = create_model(train_gen.num_classes)

    # Define number of epochs
    epochs = 100

    # Setup callbacks: write a full model file after every epoch, and stop once
    # val_loss has not improved for 3 epochs (restoring the best weights).
    checkpoint_cb = ModelCheckpoint(
        "/content/drive/MyDrive/Colab Notebooks/model_epoch_{epoch:02d}.h5",
        save_best_only=False,
        save_weights_only=False,
    )
    early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

    try:
        # Train the model
        model.fit(
            train_gen,
            steps_per_epoch=train_gen.samples // batch_size,
            epochs=epochs,
            validation_data=validation_gen,
            validation_steps=validation_gen.samples // batch_size,
            callbacks=[checkpoint_cb, early_stop]
        )
    except KeyboardInterrupt:
        print("Training interrupted manually. Saving model...")
    else:
        print("Training finished. Saving model...")
    # Save after a normal finish as well as after a manual interrupt, so the
    # model that will be served always exists at MODEL_PATH. Other exceptions
    # propagate without saving.
    model.save(MODEL_PATH)
    print("Model saved.")
else:
    # Serving an untrained network would return meaningless predictions, so a
    # missing model file is an error rather than a silent fallback.
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"No saved model at {MODEL_PATH!r}. Set should_train = True to train one first."
        )

    # The generators are created here only to obtain the class-name mapping.
    train_gen, _ = setup_data_generators()

    print("Loading saved model for API usage...")
    model = tf.keras.models.load_model(MODEL_PATH)


# Create class indices for mapping predictions (class index -> directory name)
class_indices = {v: k for k, v in train_gen.class_indices.items()}

# Define the prediction endpoint
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    The upload is written to a temporary file under ``temp_uploads``,
    classified with the module-level ``model``, and the temporary file is
    removed again.

    Args:
        file: The uploaded image.

    Returns:
        ``{"name": ..., "condition": ...}`` built from the predicted class
        label (the text before and after ``___``, with underscores replaced by
        spaces), or a 500 JSON response with a ``message`` on any error.
    """
    import tempfile  # needed only by this handler

    temp_path = None
    try:
        temp_dir = "temp_uploads"
        os.makedirs(temp_dir, exist_ok=True)
        # The client-supplied file.filename is deliberately not used to build
        # the path: it is untrusted and could contain "../" components.
        with tempfile.NamedTemporaryFile(dir=temp_dir, delete=False) as buffer:
            temp_path = buffer.name
            shutil.copyfileobj(file.file, buffer)
        predict_name = predict_image_class(model, temp_path, class_indices)
        # A label without "___" yields an empty condition instead of an error.
        plant, _, condition = predict_name.partition("___")
        info = {
            "name": plant.replace("_", " "),
            "condition": condition.replace("_", " ")
        }
        return info
    except Exception as e:
        return JSONResponse(status_code=500, content={"message": f"Error processing image: {str(e)}"})
    finally:
        # Always remove the temporary upload so files do not pile up on disk.
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)


# Start the server with ngrok
import getpass

# Allow uvicorn to run inside the notebook's event loop.
nest_asyncio.apply()
# The ngrok auth token is a secret: it is read from the NGROK_AUTHTOKEN
# environment variable, or prompted for, and never written into the notebook.
# Any token that was ever committed in a notebook should be revoked.
ngrok_auth_token = os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)
public_url = ngrok.connect(8000)
print(f"API is running at: {public_url}")
uvicorn.run(app, host="0.0.0.0", port=8000)

Found 43456 images belonging to 38 classes.
Found 10849 images belonging to 38 classes.
Loading saved model for API usage...




API is running at: NgrokTunnel: "https://8485-34-106-71-238.ngrok-free.app" -> "http://localhost:8000"


INFO:     Started server process [463]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 142ms/step
INFO:     197.51.16.212:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step
INFO:     197.51.16.212:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step
INFO:     197.51.16.212:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 78ms/step
INFO:     197.51.16.212:0 - "POST /predict/ HTTP/1.1" 200 OK
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 116ms/step
INFO:     197.51.16.212:0 - "POST /predict/ HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [463]
