### Train an image classification model

#### 1. Necessary Imports

In [1]:
# Set env variable for ArcGIS to enable TensorFlow backend.
# The classifier later in this notebook is created with backend='tensorflow'.
# NOTE(review): presumably this has to be set before `arcgis.learn` is
# imported in the next cell -- confirm against the arcgis.learn docs.
%env ARCGIS_ENABLE_TF_BACKEND=1

env: ARCGIS_ENABLE_TF_BACKEND=1


In [2]:
import os 
from pathlib import Path

from arcgis.learn import prepare_data, FeatureClassifier

#### 2. Set Dataset Path

In [None]:
# Path to the dataset; fill this in before running the cells below.
# The next cell strips the file extension from this value to get the dataset
# folder, so this presumably points at a downloaded archive (e.g. a .zip)
# that has been extracted next to it -- TODO confirm.
filepath = r''

In [None]:
# Dataset folder: `filepath` with its file extension stripped.
data_path = Path(os.path.splitext(filepath)[0])

#### 3. Filter out non-RGB images

In [3]:
from glob import glob
from PIL import Image

In [None]:
# Permanently delete every .jpg under <data_path>/images/<subfolder>/ whose
# PIL mode is not 'RGB'.
# NOTE: glob() is called without recursive=True, so '**' behaves like '*'
# and matches exactly one directory level below 'images'.
for image_filepath in glob(os.path.join(data_path, 'images', '**', '*.jpg')):
    # Read the mode inside a context manager so the file handle is closed
    # before the file is removed; an open handle blocks os.remove on Windows
    # and would otherwise only be released by garbage collection.
    with Image.open(image_filepath) as image:
        is_rgb = image.mode == 'RGB'
    if not is_rgb:
        os.remove(image_filepath)

#### 4. Prepare data

In [None]:
# Build the training data object from the folder at `data_path`, declared as
# an 'Imagenet'-type dataset, using a chip size of 300 and a batch size of 64.
data = prepare_data(
    path=data_path,
    dataset_type='Imagenet',
    chip_size=300,
    batch_size=64,
)

#### 5. Visualize a few samples from your training data

In [None]:
# Preview samples from the prepared data, requesting 2 rows.
data.show_batch(rows=2)

#### 6. Load model architecture

In [None]:
# Create a FeatureClassifier over the prepared data with a MobileNetV2
# backbone on the TensorFlow backend (enabled via ARCGIS_ENABLE_TF_BACKEND
# at the top of this notebook).
model = FeatureClassifier(data, backbone='MobileNetV2', backend='tensorflow')

#### 7. Find an optimal learning rate

In [None]:
# Keep the value returned by lr_find(); it is passed to model.fit() below.
lr = model.lr_find()

#### 8. Fit the model

In [None]:
# Train using the learning rate found in the previous step.
# NOTE(review): 25 is presumably the number of epochs (it matches the "-25-"
# in the saved model name below) -- confirm against the arcgis.learn docs.
model.fit(25, lr=lr)

#### 9. Visualize results on the validation set

In [None]:
# Display 4 rows of model results.
# NOTE(review): thresh=0.2 is presumably a minimum confidence threshold for
# displayed predictions -- confirm against the arcgis.learn docs.
model.show_results(rows=4, thresh=0.2)

#### 10. Save the model

In [None]:
# Save the trained model under the name 'Plant-identification-25-tflite',
# requesting the "tflite" framework for the saved output.
# NOTE(review): the output location is chosen by arcgis.learn and is not
# visible in this notebook -- confirm where the files are written.
model.save('Plant-identification-25-tflite', framework="tflite")