# YOLOv8 Plant Object Detection (TensorFlow Export)

This notebook trains a YOLOv8 model to detect plant objects using a dataset fetched from an API endpoint. Training runs with Ultralytics (PyTorch backend) and the trained model is exported to TensorFlow SavedModel for TF-based inference.

Workflow:
- Install dependencies (Ultralytics, TensorFlow, etc.)
- Configure API endpoint and local paths
- Download and prepare dataset in YOLO format
- Train YOLOv8 model
- Validate/test performance
- Run inference on a user-provided image
- Export to TensorFlow SavedModel and run a basic TF inference demo

Notes:
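- The API is expected to return a ZIP containing images and YOLO-format `.txt` label files. The preparation cell pairs each image with the label file that has the same file stem, performs a random train/val split, and writes an empty label file (treated as background) for any image without a matching label. Labels in other formats (e.g. COCO JSON or Pascal VOC XML) are not converted; adapt the preparation cell for those.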
- Set the configuration variables in the next cell to match your API and environment.


In [None]:
# Install pinned dependencies: Ultralytics for training/export, TensorFlow to load the exported SavedModel.
# NOTE(review): onnx / onnxruntime / tf2onnx are not imported by any cell in this notebook; presumably they
# support the export path — TODO confirm they are still needed.
# NOTE(review): in Jupyter, `%pip install ...` guarantees packages land in the running kernel's environment;
# bare `pip install` relies on IPython automagic being enabled.
pip install -q --upgrade pip
pip install -q ultralytics==8.3.29 tensorflow==2.16.1 onnx==1.16.1 onnxruntime==1.18.1 tf2onnx==1.16.1 opencv-python matplotlib tqdm requests pyyaml supervision


In [None]:
import sys, platform, os, shutil, zipfile, random, json, time
import pathlib
from pathlib import Path

for label, value in (("Python", sys.version), ("Platform", platform.platform())):
    print(f"{label}: {value}")

# PyTorch accelerator report. Any failure (package missing, driver problem)
# is reported as a message instead of interrupting the notebook.
try:
    import torch
    has_cuda = torch.cuda.is_available()
    print(f"PyTorch CUDA available: {has_cuda}")
    if has_cuda:
        print(f"CUDA device count: {torch.cuda.device_count()}")
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
except Exception:
    print("PyTorch not installed yet or no CUDA available.")

# TensorFlow version and GPU visibility, reported the same best-effort way.
try:
    import tensorflow as tf
    print(f"TensorFlow: {tf.__version__}")
    gpu_devices = tf.config.list_physical_devices('GPU')
    print(f"TF GPU: {len(gpu_devices) > 0}")
except Exception:
    print("TensorFlow not installed yet.")

from ultralytics import YOLO
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Working directories, all rooted at the current working directory:
#   data/plant/raw  -> downloaded ZIP and its extraction
#   data/plant/yolo -> YOLO-format train/val tree plus data.yaml
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR.joinpath('data', 'plant')
RAW_DIR = DATA_DIR.joinpath('raw')
YOLO_DIR = DATA_DIR.joinpath('yolo')
for target in (DATA_DIR, RAW_DIR, YOLO_DIR):
    target.mkdir(parents=True, exist_ok=True)

print('Dirs ready:', DATA_DIR, RAW_DIR, YOLO_DIR)


In [None]:
# Configuration
# The API URL and request headers can come from environment variables so that
# credentials never have to be written into the notebook itself.
API_DATA_URL = os.environ.get('PLANT_API_ZIP_URL', 'https://your-api.example.com/plant-dataset.zip')
API_HEADERS_JSON = os.environ.get('PLANT_API_HEADERS_JSON', '{}')  # e.g. '{"Authorization": "Bearer <TOKEN>"}'
DATASET_NAME = 'plant'
CLASS_NAMES = ['plant']  # Update if multiple classes; position i becomes class id i in data.yaml
VAL_SPLIT = 0.2  # fraction of images placed in the validation split
RANDOM_SEED = 42
MODEL_VARIANT = 'yolov8n.pt'  # choose from yolov8n/s/m/l/x depending on GPU
EPOCHS = 50
IMG_SIZE = 640
BATCH = 16

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)


def describe_headers(raw_json):
    """Return the header names in ``raw_json`` without their values.

    Header values often carry credentials (bearer tokens, API keys), so only the
    sorted names are exposed for display.

    Returns ``['<invalid JSON>']`` if ``raw_json`` cannot be parsed, and
    ``['<not a JSON object>']`` if it parses to something other than an object.
    """
    try:
        parsed = json.loads(raw_json)
    except (json.JSONDecodeError, TypeError):
        return ['<invalid JSON>']
    if not isinstance(parsed, dict):
        return ['<not a JSON object>']
    return sorted(parsed)


print('Using API:', API_DATA_URL)
# Never echo header values: they may contain secrets.
print('Header names (values hidden):', describe_headers(API_HEADERS_JSON))


In [None]:
# Download dataset ZIP from API endpoint
import requests

# Request headers must be a JSON object; anything else falls back to no headers.
try:
    headers = json.loads(API_HEADERS_JSON)
except (json.JSONDecodeError, TypeError):
    headers = None
if not isinstance(headers, dict):
    print('Invalid headers JSON, using empty headers')
    headers = {}

zip_path = RAW_DIR / 'dataset.zip'

# A truncated/corrupt archive left by an earlier failed run would otherwise be
# reused forever (the download is skipped when the file exists); discard it.
if zip_path.exists() and not zipfile.is_zipfile(zip_path):
    print('Existing ZIP is not a valid archive, re-downloading:', zip_path)
    zip_path.unlink()

if not zip_path.exists():
    print('Downloading dataset...')
    # Stream into a temporary file and only move it into place once the
    # download completed and looks like a ZIP, so zip_path is never partial.
    tmp_path = zip_path.with_name(zip_path.name + '.part')
    try:
        with requests.get(API_DATA_URL, headers=headers, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        if not zipfile.is_zipfile(tmp_path):
            raise ValueError(f'Downloaded file from {API_DATA_URL} is not a ZIP archive')
    except BaseException:
        # Also covers KeyboardInterrupt (interrupting a stuck download); re-raised below.
        tmp_path.unlink(missing_ok=True)
        raise
    tmp_path.replace(zip_path)
    print('Downloaded to', zip_path)
else:
    print('ZIP already exists at', zip_path)

# Extract into a fresh directory so files from previous extractions do not linger.
extract_dir = RAW_DIR / 'extracted'
if extract_dir.exists():
    shutil.rmtree(extract_dir)
extract_dir.mkdir(parents=True, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as zf:
    zf.extractall(extract_dir)

print('Extracted to', extract_dir)

# Inspect extracted structure (top level only)
for root, dirs, files in os.walk(extract_dir):
    print(root, 'dirs:', len(dirs), 'files:', len(files))
    break


In [None]:
# Prepare YOLO dataset structure: data/yolo/{images,labels}/{train,val}
# Each image is paired with the YOLO-format .txt label that has the same file
# stem; images without a label get an empty label file (background image).

image_exts = {'.jpg', '.jpeg', '.png', '.bmp'}
label_ext = '.txt'

images_all = []
labels_all = {}

for root, dirs, files in os.walk(extract_dir):
    # os.walk order is filesystem-dependent; sort so the seeded shuffle below
    # yields the same split on every machine.
    dirs.sort()
    for f in sorted(files):
        fp = Path(root) / f
        ext = fp.suffix.lower()
        if ext in image_exts:
            images_all.append(fp)
        elif ext == label_ext:
            labels_all[fp.stem] = fp

print(f'Total images found: {len(images_all)}')
if len(images_all) < 2:
    raise RuntimeError(
        f'Need at least 2 images to build train and val splits, found {len(images_all)} in {extract_dir}'
    )

# Stems are the pairing key and the destination file name, so images sharing a
# stem (e.g. from different sub-folders) overwrite each other's copies/labels.
num_unique_stems = len({p.stem for p in images_all})
if num_unique_stems != len(images_all):
    print(f'WARNING: {len(images_all) - num_unique_stems} image(s) share a file stem with another image; '
          'their copies and labels will collide.')

# Simple split, keeping at least one image in each of train and val.
random.shuffle(images_all)
num_val = min(max(1, int(len(images_all) * VAL_SPLIT)), len(images_all) - 1)
val_images = set(p.stem for p in images_all[:num_val])

# Start from an empty output tree: files (and Ultralytics *.cache files) left
# by a previous run would otherwise leak into the wrong split.
for sub in ['images', 'labels']:
    if (YOLO_DIR / sub).exists():
        shutil.rmtree(YOLO_DIR / sub)

for split in ['train', 'val']:
    for sub in ['images', 'labels']:
        (YOLO_DIR / sub / split).mkdir(parents=True, exist_ok=True)

missing_labels = 0
for img_path in images_all:
    stem = img_path.stem
    split = 'val' if stem in val_images else 'train'
    dst_img = YOLO_DIR / 'images' / split / img_path.name
    shutil.copy2(img_path, dst_img)

    label_src = labels_all.get(stem)
    dst_lbl = YOLO_DIR / 'labels' / split / f'{stem}.txt'
    if label_src and label_src.exists():
        shutil.copy2(label_src, dst_lbl)
    else:
        # If no label, create empty label file (treat as background)
        missing_labels += 1
        dst_lbl.touch()

print('Missing labels:', missing_labels)

# Write data.yaml
import yaml

data_yaml = {
    'path': str(YOLO_DIR),
    'train': 'images/train',
    'val': 'images/val',
    'names': {i: name for i, name in enumerate(CLASS_NAMES)},
}

yaml_path = YOLO_DIR / 'data.yaml'
with open(yaml_path, 'w') as f:
    yaml.safe_dump(data_yaml, f)

print('Wrote', yaml_path)
print(yaml.safe_dump(data_yaml))


In [None]:
# Train YOLOv8 model

model = YOLO(MODEL_VARIANT)

results = model.train(
    data=str((YOLO_DIR / 'data.yaml').resolve()),
    epochs=EPOCHS,
    imgsz=IMG_SIZE,
    batch=BATCH,
    seed=RANDOM_SEED,  # use the configured seed instead of Ultralytics' default
    project=str((BASE_DIR / 'runs').resolve()),
    name='plant_yolov8',
)

# A YOLO model object has no `best` attribute; the trainer records the path of the
# best checkpoint. NOTE(review): Ultralytics reloads `model` from that checkpoint
# at the end of train() — confirm for the pinned version.
best_weights = Path(model.trainer.best)
print('Training complete. Best weights at:', best_weights)


In [None]:
# Validate / test
# Evaluates the current weights on the 'val' split declared in data.yaml.
metrics = model.val()

# Print the headline detection metrics instead of dumping the whole metrics object.
box = metrics.box
print(f'precision: {box.mp:.4f}')
print(f'recall:    {box.mr:.4f}')
print(f'mAP50:     {box.map50:.4f}')
print(f'mAP50-95:  {box.map:.4f}')


In [None]:
# Inference on a user-provided image

USER_IMAGE_PATH = os.environ.get('PLANT_TEST_IMAGE', '')  # or set a path here
if not USER_IMAGE_PATH:
    print('Set PLANT_TEST_IMAGE env var to an image path to run inference.')
elif not Path(USER_IMAGE_PATH).is_file():
    print('PLANT_TEST_IMAGE does not point to an existing file:', USER_IMAGE_PATH)
else:
    # Prefer the best checkpoint recorded by the trainer (YOLO objects have no
    # `best` attribute); otherwise use the in-memory model as-is.
    best_ckpt = getattr(getattr(model, 'trainer', None), 'best', None)
    if best_ckpt is not None and Path(best_ckpt).exists():
        infer_model = YOLO(str(best_ckpt))
    else:
        infer_model = model
    res = infer_model.predict(source=USER_IMAGE_PATH, imgsz=IMG_SIZE, conf=0.25)

    save_path = BASE_DIR / 'runs' / 'detect' / 'plant_infer'
    save_path.mkdir(parents=True, exist_ok=True)
    out_file = save_path / f"result_{Path(USER_IMAGE_PATH).name}"

    annotated = None  # stays None if predict() yields no results
    for r in res:
        annotated = r.plot()  # BGR image
        cv2.imwrite(str(out_file), annotated)
        print('Saved result to', out_file)

    # Show inline
    bgr = cv2.imread(USER_IMAGE_PATH)
    if bgr is None:
        print('OpenCV could not read the input image for display:', USER_IMAGE_PATH)
    else:
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(8, 6)); plt.imshow(rgb); plt.axis('off'); plt.title('Input')
        plt.show()

    if annotated is None:
        print('No inference results to display.')
    else:
        rgb_out = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
        plt.figure(figsize=(8, 6)); plt.imshow(rgb_out); plt.axis('off'); plt.title('Detections')
        plt.show()


In [None]:
# Export to TensorFlow SavedModel and basic TF inference demo

export_dir = BASE_DIR / 'exports'
export_dir.mkdir(parents=True, exist_ok=True)

# Ultralytics calls the TF SavedModel export format 'saved_model' ('tf' is not a
# valid format name). The export is written next to the weights file.
export_results = model.export(format='saved_model', imgsz=IMG_SIZE, keras=False)
print('Exported to:', export_results)

# Locate SavedModel directory. A path string is expected; tolerate a list of
# paths (take the last) or an empty result.
if isinstance(export_results, (list, tuple)):
    export_results = export_results[-1] if export_results else None
saved_model_dir = Path(export_results) if export_results else None

# Keep a copy under ./exports so the artifact is easy to find.
if saved_model_dir is not None and saved_model_dir.is_dir():
    local_copy = export_dir / saved_model_dir.name
    shutil.copytree(saved_model_dir, local_copy, dirs_exist_ok=True)
    saved_model_dir = local_copy

print('SavedModel at:', saved_model_dir)

# Basic TF inference using tf.saved_model (best-effort demo: failures are reported, not raised)
if saved_model_dir is None or not saved_model_dir.is_dir():
    print('Skip TF demo: no SavedModel directory was produced.')
else:
    try:
        import tensorflow as tf
        imported = tf.saved_model.load(str(saved_model_dir))
        infer = imported.signatures.get('serving_default')
        if infer is None:
            raise RuntimeError("SavedModel has no 'serving_default' signature")
        print('Loaded TF SavedModel')
        print('Input signature:', infer.structured_input_signature)

        # Prepare image
        if USER_IMAGE_PATH:
            img = cv2.imread(USER_IMAGE_PATH)
            if img is None:
                raise FileNotFoundError(f'Could not read image: {USER_IMAGE_PATH}')
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # Plain stretch to IMG_SIZE x IMG_SIZE (no letterboxing), so raw box
            # outputs refer to the stretched image.
            img_resized = cv2.resize(img_rgb, (IMG_SIZE, IMG_SIZE))
            # NOTE(review): assumes an NHWC float32 input scaled to [0, 1] — confirm
            # against the input signature printed above.
            x = img_resized.astype('float32') / 255.0
            x = np.expand_dims(x, axis=0)

            out = infer(tf.constant(x))
            print('TF inference outputs:', list(out.keys()))
            print('Output shapes:', {k: tuple(v.shape) for k, v in out.items()})
        else:
            print('Skip TF demo: set PLANT_TEST_IMAGE to run.')
    except Exception as e:
        print('TF export/inference demo failed:', e)
