# Insect Classification Pipeline with Bplusplus

This notebook demonstrates a complete pipeline for detecting and classifying insects in videos using the Bplusplus library.

## What you'll accomplish:
1. **Collect** insect images from the GBIF database
2. **Prepare** data by detecting and cropping insects from raw images
3. **Train** a classification model for family, genus, and species identification
4. **Test** the trained model's performance
5. **Run inference** on videos to detect, classify, and track insects over time

## How it works:
- Uses a **two-stage approach**: pre-trained detection + custom classification
- Detection stage: Locates insects in images/video frames
- Classification stage: Identifies insects at three taxonomic levels (family → genus → species)
- Tracking stage: Follows insects over time and aggregates predictions per individual

By the end, you'll have a system that can process videos and output the most likely classification for each tracked insect.



## Setup: Create Virtual Environment (Recommended)

Create an isolated environment to avoid package conflicts:

```bash
python3 -m venv bplusplus_env
source bplusplus_env/bin/activate
```

## Setup: Install Required Packages

In [None]:
%pip install bplusplus

## Import required packages

In [1]:
import bplusplus
from typing import Any
from pathlib import Path
import requests
from tqdm import tqdm

## Set directories

In [2]:
MAIN_DIR = Path("./")

GBIF_DATA_DIR = MAIN_DIR / "GBIF_data"
PREPARED_DATA_DIR = MAIN_DIR / "prepared_data"
TRAINED_MODEL_DIR = MAIN_DIR / "trained_model"

## Step 1: Collect Insect Images from GBIF

We download images from the GBIF (Global Biodiversity Information Facility) database for our target species.

**Important notes:**
- Download more images than needed - many will be filtered out during preparation
- Long downloads can be interrupted by an unstable connection: monitor progress and resume if needed
- Check the `GBIF_DATA_DIR` folder to track downloaded files

In [3]:
# Target species, given as binomial scientific names
names = ["Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius"]

search: dict[str, Any] = {
    "scientificName": names
}

In [None]:
bplusplus.collect(
    group_by_key=bplusplus.Group.scientificName,
    search_parameters=search, 
    images_per_group=100,
    output_directory=GBIF_DATA_DIR,
    num_threads=3
)
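
To follow download progress as suggested in the notes above, a small helper can count the image files in each subfolder of the download directory. This is a minimal sketch: it assumes the downloads end up in one subfolder per group, which is not confirmed here, and `count_images_per_folder` and `IMAGE_SUFFIXES` are illustrative names, not part of Bplusplus.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust to your downloads).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images_per_folder(root: Path) -> dict[str, int]:
    """Count image files below each immediate subfolder of `root`."""
    if not root.is_dir():
        return {}
    counts: dict[str, int] = {}
    for folder in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[folder.name] = sum(
            1
            for f in folder.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


# "GBIF_data" is the folder this notebook assigns to GBIF_DATA_DIR.
for folder_name, n_images in count_images_per_folder(Path("GBIF_data")).items():
    print(f"{folder_name}: {n_images} images")
```

Rerunning this while the download cell is active shows whether the counts are still growing.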


## Step 2: Prepare Data for Training

This step uses a pre-trained vision model to:
1. **Detect** insects in the raw GBIF images
2. **Crop** each detected insect to focus on the subject
3. **Resize** images to a consistent size for training

**What to expect:**
- Many images will be rejected (low success rate is normal)
- Only clear, well-detected insects proceed to training
- This filtering ensures high-quality training data

In [None]:
bplusplus.prepare(
    input_directory=GBIF_DATA_DIR,
    output_directory=PREPARED_DATA_DIR,
    img_size=60
)
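
Since a low success rate is expected, it can help to measure how much data survived preparation. The sketch below compares the total number of image files in the raw and prepared folders. It makes no assumption about the folder layout inside either directory; the helper names are illustrative and not part of Bplusplus. One raw image may yield several crops, so treat the ratio as a rough indicator rather than an exact acceptance rate.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust to your data).
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(root: Path) -> int:
    """Count image files anywhere below `root` (0 if the folder is missing)."""
    if not root.is_dir():
        return 0
    return sum(
        1
        for f in root.rglob("*")
        if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
    )


def retention(raw_dir: Path, prepared_dir: Path) -> float:
    """Prepared images per raw image; 0.0 when there are no raw images."""
    raw = count_images(raw_dir)
    return count_images(prepared_dir) / raw if raw else 0.0


# Folder names match GBIF_DATA_DIR and PREPARED_DATA_DIR in this notebook.
ratio = retention(Path("GBIF_data"), Path("prepared_data"))
print(f"Prepared images per raw image: {ratio:.2f}")
```

If the ratio is very low, collecting more images per group in Step 1 is the simplest remedy.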

## Step 3: Train the Classification Model

Train a ResNet-based neural network that classifies insects at three taxonomic levels:
- **Family** (e.g., Coccinellidae for ladybugs)
- **Genus** (e.g., Coccinella)  
- **Species** (e.g., Coccinella septempunctata)

The model learns hierarchical relationships between these classification levels. 
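
To make the three levels concrete, the sketch below spells out the family, genus, and species for the three target species of this notebook. The family table is written by hand for illustration; how Bplusplus resolves the taxonomy internally is not shown here, and `taxonomy` is a hypothetical helper.

```python
# Family of each genus used in this notebook (hand-written for illustration).
FAMILY_BY_GENUS = {
    "Coccinella": "Coccinellidae",
    "Apis": "Apidae",
    "Bombus": "Apidae",
}


def taxonomy(species: str) -> tuple[str, str, str]:
    """Return (family, genus, species) for a binomial species name."""
    genus = species.split()[0]  # the first word of a binomial name is the genus
    return FAMILY_BY_GENUS[genus], genus, species


species_names = ["Coccinella septempunctata", "Apis mellifera", "Bombus lapidarius"]
for species in species_names:
    print(" → ".join(taxonomy(species)))
```

Note that two of the three species share the family Apidae, so for this species list the family level only separates the ladybug from the two bees.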

In [None]:
bplusplus.train(
    batch_size=4,
    epochs=9,
    patience=3,
    img_size=60,
    data_dir=PREPARED_DATA_DIR,
    output_dir=TRAINED_MODEL_DIR,
    species_list=names
)


## Step 4: Download Pre-trained Detection Weights

For the two-stage pipeline, we need pre-trained weights for insect detection.

**Option 1:** Manual download from:
https://github.com/Tvenver/Bplusplus/releases/download/v1.2.3/v11small-generic.pt

**Option 2:** Automatic download using the code below:


In [9]:
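def __download_file_from_github_release(url, dest_path):
    """
    Download a file from a GitHub release URL and save it to dest_path,
    displaying a progress bar while the download runs.

    Args:
        url (str): The URL of the file to download.
        dest_path (Path): The destination path where the file will be saved.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    dest_path = Path(dest_path)
    # Make sure the target folder exists (e.g. if training has not run yet).
    dest_path.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temporary name so an interrupted download is never
    # mistaken for a complete weights file.
    tmp_path = dest_path.with_name(dest_path.name + ".part")

    # The timeout keeps the cell from hanging on a stalled connection.
    with requests.get(url, stream=True, timeout=30) as response:
        # Fail early, before creating the file or the progress bar.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 KiB

        with open(tmp_path, 'wb') as f, tqdm(
            total=total_size, unit='B', unit_scale=True, unit_divisor=1024
        ) as progress_bar:
            for chunk in response.iter_content(chunk_size=block_size):
                f.write(chunk)
                progress_bar.update(len(chunk))

    # Only a fully written file gets the final name.
    tmp_path.replace(dest_path)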

In [4]:
YOLO_WEIGHTS = TRAINED_MODEL_DIR / "v11small-generic.pt"

In [None]:
github_release_url = 'https://github.com/Tvenver/Bplusplus/releases/download/v1.2.3/v11small-generic.pt'

if not YOLO_WEIGHTS.exists():
    __download_file_from_github_release(github_release_url, YOLO_WEIGHTS)

## Step 5: Test the Trained Model

Evaluate your trained model on a test dataset to measure accuracy.

**Test data requirements:**
- Directory structure: `test_data/images/` and `test_data/labels/`
- Images: Standard image formats (jpg, png, etc.)
- Labels: YOLO format text files with bounding boxes
- Label format: `<class> <x_center> <y_center> <width> <height>`
- Class indices in the label files must follow the same order as the species list used for training

**Note:** This tests detection + classification accuracy, but does not include tracking.
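
The label format listed above can be checked before running the test. The sketch below parses one label line, validates it, and converts the box to pixel corners. It relies on the usual YOLO convention that class indices start at 0 and that the four box values are normalized to the range 0 to 1; `YoloBox`, `parse_yolo_line`, and `to_pixel_corners` are illustrative names, not part of Bplusplus.

```python
from typing import NamedTuple


class YoloBox(NamedTuple):
    class_id: int
    x_center: float
    y_center: float
    width: float
    height: float


def parse_yolo_line(line: str, num_classes: int) -> YoloBox:
    """Parse one `<class> <x_center> <y_center> <width> <height>` line."""
    parts = line.split()
    if len(parts) != 5:
        raise ValueError(f"expected 5 fields, got {len(parts)}: {line!r}")
    class_id = int(parts[0])
    coords = [float(p) for p in parts[1:]]
    if not 0 <= class_id < num_classes:
        raise ValueError(f"class index {class_id} outside 0..{num_classes - 1}")
    if any(not 0.0 <= c <= 1.0 for c in coords):
        raise ValueError(f"box values must be normalized to [0, 1]: {line!r}")
    return YoloBox(class_id, *coords)


def to_pixel_corners(
    box: YoloBox, img_w: int, img_h: int
) -> tuple[float, float, float, float]:
    """Convert a normalized box to (x_min, y_min, x_max, y_max) in pixels."""
    cx, cy = box.x_center * img_w, box.y_center * img_h
    half_w, half_h = box.width * img_w / 2, box.height * img_h / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h


# Class 1 would be the second entry of the species list.
box = parse_yolo_line("1 0.5 0.5 0.2 0.4", num_classes=3)
print(box)
print(to_pixel_corners(box, img_w=640, img_h=480))
```

Running `parse_yolo_line` over every line of every file in `test_data/labels/` catches malformed labels before they surface as confusing test results.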


In [5]:
TEST_DATA_DIR = MAIN_DIR / "test_data"  # point this at another folder to test the two-stage pipeline on a different dataset
RESNET_MULTITASK_WEIGHTS = TRAINED_MODEL_DIR / "best_multitask.pt"

In [None]:
bplusplus.test(
    species_list=names,
    test_set=TEST_DATA_DIR,
    yolo_weights=YOLO_WEIGHTS,
    hierarchical_weights=RESNET_MULTITASK_WEIGHTS,
    output_dir=TRAINED_MODEL_DIR
)


## Step 6: Run Video Inference

Process a video to detect, classify, and track insects over time.

**How it works:**
1. **Detection:** Finds insects in each video frame
2. **Classification:** Identifies each detected insect 
3. **Tracking:** Follows individual insects across frames
4. **Aggregation:** Combines predictions for each tracked insect
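
To illustrate the aggregation step, the sketch below picks one label per tracked insect by summing the per-frame confidences of each predicted label. This is only one plausible rule, chosen for illustration; the rule Bplusplus actually applies is not described here, and `aggregate_tracks` is a hypothetical helper.

```python
from collections import defaultdict


def aggregate_tracks(
    detections: list[tuple[int, str, float]],
) -> dict[int, str]:
    """Choose one label per track from (track_id, label, confidence) tuples.

    The label with the highest summed confidence across frames wins.
    """
    scores: dict[int, dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for track_id, label, confidence in detections:
        scores[track_id][label] += confidence
    return {
        track_id: max(label_scores, key=label_scores.get)
        for track_id, label_scores in scores.items()
    }


# Made-up per-frame predictions for two tracked insects.
frames = [
    (1, "Apis mellifera", 0.9),
    (1, "Bombus lapidarius", 0.4),
    (1, "Apis mellifera", 0.7),
    (2, "Coccinella septempunctata", 0.8),
]
print(aggregate_tracks(frames))
```

Summing confidences lets a few uncertain frames be outvoted by consistent predictions over the rest of the track.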

**Key parameters:**
- `fps`: Video processing frame rate (None = use original)
- `tracker_max_frames`: How long to remember lost insects
  - Example: 60 frames at 15 fps = 4 seconds before forgetting an insect
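
The relation between `tracker_max_frames` and elapsed time is a simple division, sketched below with two hypothetical helpers. The 15 fps value is taken from the example above; the frame rate of your own video may differ.

```python
def frames_to_seconds(frames: int, fps: float) -> float:
    """Time span covered by `frames` frames at `fps` frames per second."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return frames / fps


def seconds_to_frames(seconds: float, fps: float) -> int:
    """Frames needed to cover `seconds` at `fps`, rounded to the nearest frame."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return round(seconds * fps)


print(frames_to_seconds(60, 15))                      # the example above
print(f"{frames_to_seconds(200, 15):.1f} s for 200 frames at 15 fps")
print(seconds_to_frames(4, 15))                       # inverse of the example
```

Working backwards with `seconds_to_frames` is convenient when you know how long an insect may leave the frame and want the matching `tracker_max_frames` value.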

In [6]:
VIDEO_INPUT_PATH = MAIN_DIR / 'videos' / "test_video.mp4"
VIDEO_OUTPUT_PATH = MAIN_DIR / 'videos' / "test_video_output.mp4"

In [None]:
bplusplus.inference(
    species_list=names,
    yolo_model_path=YOLO_WEIGHTS,
    hierarchical_model_path=RESNET_MULTITASK_WEIGHTS,
    confidence_threshold=0.35,
    video_path=VIDEO_INPUT_PATH,
    output_path=VIDEO_OUTPUT_PATH,
    tracker_max_frames=200,
    fps=None
)