An improved weed detection model based on RetinaNet. This repository provides the source code and public dataset for the paper "WeedNet-R: A sugar beet field weed detection algorithm based on enhanced RetinaNet and context semantic fusion".
- Clone this repository.
- Install the environment:
    cd code
    pip install -r requirements.txt
- Download the SugarBeets2016 dataset from the dataset link (extraction code: zr06).
- Unpack the dataset to a directory of your choice.
- Run tocsv.py under the ./dataset directory.
- Copy the generated train.csv, test.csv, and val.csv files to ./code/dataset/.
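The training and validation commands also take a class_list.csv. In pytorch-retinanet, annotation rows have the form path/to/image.jpg,x1,y1,x2,y2,class_name and the class list maps class_name,id with ids starting at 0. The hypothetical helper below sketches how such a class list could be built from the generated CSVs, assuming tocsv.py emits that pytorch-retinanet format without a header row; it is not part of this repository.

```python
import csv


def write_class_list(annotation_csvs, out_path):
    """Collect class names from pytorch-retinanet-style annotation CSVs
    (path,x1,y1,x2,y2,class_name) and write a class_name,id mapping.

    Hypothetical helper: assumes the CSVs from tocsv.py have no header row.
    """
    names = set()
    for csv_path in annotation_csvs:
        with open(csv_path, newline="") as f:
            for row in csv.reader(f):
                # Rows for images without objects look like "path,,,,," and
                # have an empty class name, so they are skipped.
                if len(row) == 6 and row[5]:
                    names.add(row[5])
    # Sort for a deterministic mapping; ids start at 0.
    mapping = {name: idx for idx, name in enumerate(sorted(names))}
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        for name, idx in mapping.items():
            writer.writerow([name, idx])
    return mapping
```

For example, write_class_list(["./code/dataset/train.csv", "./code/dataset/val.csv"], "./code/dataset/class_list.csv") would write one line per class found in the training and validation annotations. Sorting the names keeps the ids stable between runs.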
To train the model, run:
python train.py --dataset csv --csv_train <path/to/train_annots.csv> --csv_classes <path/to/train/class_list.csv> --csv_val <path/to/val_annots.csv>
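Before training, it can help to sanity-check the annotation CSVs passed to --csv_train and --csv_val. The hypothetical check below assumes the pytorch-retinanet row format (path,x1,y1,x2,y2,class_name) with integer pixel coordinates, requires x2 > x1 and y2 > y1 for every box, and accepts "path,,,,," rows for images without objects. It is a sketch, not the repository's own loader.

```python
import csv


def check_annotations(csv_path):
    """Return a list of problems found in a pytorch-retinanet-style CSV.

    Hypothetical pre-flight check; an empty list means nothing was flagged.
    """
    problems = []
    with open(csv_path, newline="") as f:
        for line_no, row in enumerate(csv.reader(f), start=1):
            if len(row) != 6:
                problems.append(f"line {line_no}: expected 6 fields, got {len(row)}")
                continue
            path, x1, y1, x2, y2, name = row
            if not path:
                problems.append(f"line {line_no}: empty image path")
                continue
            if (x1, y1, x2, y2, name) == ("", "", "", "", ""):
                continue  # image without annotations
            try:
                # Assumes integer pixel coordinates.
                x1, y1, x2, y2 = (int(v) for v in (x1, y1, x2, y2))
            except ValueError:
                problems.append(f"line {line_no}: non-integer box coordinates")
                continue
            if x2 <= x1 or y2 <= y1:
                problems.append(f"line {line_no}: degenerate box ({x1},{y1},{x2},{y2})")
            if not name:
                problems.append(f"line {line_no}: missing class name")
    return problems
```

For example, print(check_annotations("./code/dataset/train.csv")) lists any malformed rows with their line numbers, which is easier to act on than an error raised partway through data loading.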
A pre-trained model is available at: WeedNet-R pretrained model (password: k3xf).
Run the following script to validate:
python csv_validation.py --csv_annotations_path ./dataset/test.csv --model_path path/to/model.pt --images_path path/to/images_dir --class_list_path path/to/class_list.csv
Optionally, set the IoU threshold used for evaluation with the iou_threshold argument (0 < iou_threshold < 1).
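The IoU threshold decides when a predicted box counts as matching a ground-truth box. The sketch below computes intersection over union for (x1, y1, x2, y2) boxes in continuous coordinates (no +1 pixel correction) and assumes a prediction matches when IoU >= iou_threshold; the repository's evaluation code may use a different convention, and the 0.5 default here is only an assumption for illustration.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    # Corners of the overlapping rectangle (may be empty).
    ix1 = max(box_a[0], box_b[0])
    iy1 = max(box_a[1], box_b[1])
    ix2 = min(box_a[2], box_b[2])
    iy2 = min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def is_match(pred_box, gt_box, iou_threshold=0.5):
    """Treat a detection as a match when IoU >= iou_threshold (assumed rule)."""
    if not 0.0 < iou_threshold < 1.0:
        raise ValueError("iou_threshold must satisfy 0 < iou_threshold < 1")
    return iou(pred_box, gt_box) >= iou_threshold
```

Two 10x10 boxes offset by half their width, (0, 0, 10, 10) and (5, 0, 15, 10), overlap by 50 of 150 units of area, so their IoU is 1/3: they match at a threshold of 0.3 but not at 0.5. Raising the threshold therefore demands tighter localization before a detection counts as correct.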
To visualize the detected bounding boxes on the validation set of a CSV dataset, use:
python visualize.py --dataset csv --csv_classes <path/to/train/class_list.csv> --csv_val <path/to/val_annots.csv> --model <path/to/model.pt>
The RetinaNet model uses a ResNet backbone (pre-trained weights: download link, password: v4o1). You can set the depth of the ResNet model with the --depth argument. The depth must be one of 18, 34, 50, 101, or 152. Deeper models are generally more accurate but slower and use more memory.
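A command-line option restricted to the supported depths can be expressed with argparse choices, so an unsupported value is rejected before any model is built. This is a minimal sketch, not the repository's train.py: the parse_depth and backbone_name helpers, the default of 50, and the "resnet<depth>" naming are assumptions for illustration.

```python
import argparse

# Depths listed in this README as supported by the ResNet backbone.
SUPPORTED_DEPTHS = (18, 34, 50, 101, 152)


def parse_depth(argv=None):
    """Parse a --depth option; argparse exits with an error for other values."""
    parser = argparse.ArgumentParser(description="Backbone depth example")
    parser.add_argument(
        "--depth",
        type=int,
        choices=SUPPORTED_DEPTHS,
        default=50,  # assumed default for this sketch only
        help="ResNet depth: one of 18, 34, 50, 101, 152",
    )
    return parser.parse_args(argv).depth


def backbone_name(depth):
    """Map a depth to a name such as 'resnet50' (hypothetical naming)."""
    if depth not in SUPPORTED_DEPTHS:
        raise ValueError(f"Unsupported depth {depth}; must be one of {SUPPORTED_DEPTHS}")
    return f"resnet{depth}"
```

For example, parse_depth(["--depth", "101"]) returns 101, while "--depth 20" makes argparse print a usage error and exit. Validating at parse time keeps the depth check in one place instead of in a chain of if/else branches.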
The original weed dataset is from SugarBeets2016.
The base RetinaNet network is from pytorch-retinanet.