In [1]:
## Adjust the paths for your environment

import sys

# Show the interpreter search path before adding the local ultralytics checkout.
# Appended entries are searched last, so an ultralytics package found earlier on
# sys.path takes precedence over this directory.
print(sys.path)
local_ultralytics_dir = '/hy-tmp/yolo_soybean/ultralytics'
sys.path.append(local_ultralytics_dir)

['/hy-tmp/yolo_soybean_sam/ultralytics', '/usr/local/miniconda3/lib/python38.zip', '/usr/local/miniconda3/lib/python3.8', '/usr/local/miniconda3/lib/python3.8/lib-dynload', '', '/usr/local/miniconda3/lib/python3.8/site-packages']


# Train 

### Step 1: Train the standard YOLO model

In [3]:
from ultralytics import YOLO

device = 'cuda'

# Load pretrained YOLOv8m weights as the starting point for training.
# (Building a fresh model from "yolov8m.yaml" first is unnecessary: that object
# would be replaced by this load immediately.)
model = YOLO("yolov8m.pt")
model = model.to(device)

In [None]:
# Train the detector on the Normal_dataset described by normal.yaml.
normal_data_yaml = "/hy-tmp/yolo_soybean/datasets/Normal_dataset/normal.yaml"
results = model.train(
    data=normal_data_yaml,
    epochs=300,
    imgsz=640,
    batch=16,
)

In [6]:
def load_labels(label_path, count_map=None):
    """Return the ground-truth bean count for one YOLO-format label file.

    Every non-blank line starts with an integer class ID. Each class ID is
    mapped to a number of beans through ``count_map``, and the values are summed.

    Args:
        label_path: Path to a YOLO ``.txt`` label file.
        count_map: Optional mapping from class ID to bean count. Defaults to
            ``{0: 1, 1: 2, 2: 3, 3: 4}``.

    Returns:
        int: Total number of beans described by the file.

    Raises:
        KeyError: If a line uses a class ID that is not in ``count_map``.
    """
    if count_map is None:
        count_map = {0: 1, 1: 2, 2: 3, 3: 4}
    total_count = 0
    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank/whitespace-only lines carry no annotation; indexing
                # fields[0] on them would raise IndexError.
                continue
            total_count += count_map[int(fields[0])]
    return total_count

def count_beans(predictions):
    """Sum the bean count of detected boxes and collect each box's centre.

    Args:
        predictions: Iterable of 6-element rows ``(x1, y1, x2, y2, conf, class_id)``.

    Returns:
        tuple: ``(total_count, centers)``. ``total_count`` sums the per-class bean
        counts ``{0: 1, 1: 2, 2: 3, 3: 4}``. ``centers`` is a list of
        ``[center_x, center_y]`` pairs in the same order as ``predictions``.
    """
    beans_per_class = {0: 1, 1: 2, 2: 3, 3: 4}  # class ID -> number of beans
    bean_total = 0
    box_centers = []
    for x_min, y_min, x_max, y_max, _score, cls in predictions:
        bean_total += beans_per_class[int(cls)]
        box_centers.append([(x_min + x_max) / 2, (y_min + y_max) / 2])
    return bean_total, box_centers

from SAMHQ.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Build the HQ-SAM ViT-H model from its checkpoint, move it to the GPU and wrap
# it in a point-promptable predictor.
model_type = "vit_h"
sam_checkpoint = "/hy-tmp/yolo_soybean_sam/ckpts/sam_hq_vit_h.pth"
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)



  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>


In [None]:
## Save SAM-masked training images for re-training

In [10]:
import cv2
import PIL.Image as Image
import glob
import os
import numpy as np

# For every training image:
# 1. run the step-1 detector;
# 2. prompt HQ-SAM with the centres of the detected boxes;
# 3. save a copy in which everything outside the SAM mask is black.
# NOTE(review): only images are written here; the labels are not copied to
# train_sam/ — confirm the step-2 dataset YAML pairs these images with labels.
test_images_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/train/images'
labels_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/train/labels'
save_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/train_sam/images'

# PIL's Image.save() does not create missing parent directories.
os.makedirs(save_path, exist_ok=True)

# Sorted so runs are reproducible (glob order depends on the filesystem).
image_paths = sorted(glob.glob(os.path.join(test_images_path, '*.*')))

for image_path in image_paths:
    image_name = os.path.basename(image_path)
    image_stem = os.path.splitext(image_name)[0]
    label_path = os.path.join(labels_path, image_stem + '.txt')

    # Only labelled images are useful for re-training.
    if not os.path.exists(label_path):
        print(f'Label file not found for image {image_name}, skipping...')
        continue

    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        print(f'Could not read image {image_name}, skipping...')
        continue
    image_ori = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    # Reset per image so a previous image's points are never reused.
    centers = []
    results = model(image_path)
    if results:
        _, centers = count_beans(results[0].boxes.data.cpu().numpy())

    if not centers:
        # SAM needs at least one point prompt; an empty prompt array would make
        # predictor.predict() fail, so images without detections are skipped.
        print(f'No detections for image {image_name}, skipping...')
        continue

    predictor.set_image(image_ori)
    masks, scores, logits = predictor.predict(
        point_coords=np.array(centers),
        point_labels=np.ones(len(centers), dtype=int),  # all points are foreground
        multimask_output=True,
    )
    # masks[0] is kept (not the best-scoring mask) so these training images
    # match the preprocessing used in the evaluation cells.
    mask = masks[0]
    extracted = np.zeros_like(image_ori)
    extracted[mask] = image_ori[mask]

    re_image_path = os.path.join(save_path, image_stem + '.png')
    Image.fromarray(extracted).save(re_image_path)
    



image 1/1 /hy-tmp/yolo_soybean_sam/dataset/outdoor/train/images/D12E-D13E_DSC04509_a_1_D13E_a_1.png: 640x352 10 1spps, 45 2spps, 35 3spps, 9.9ms
Speed: 7.4ms preprocess, 9.9ms inference, 1.4ms postprocess per image at shape (1, 3, 640, 352)

image 1/1 /hy-tmp/yolo_soybean_sam/dataset/outdoor/train/images/D12E-D13E_DSC04509_a_2_D13E_a_2.png: 640x224 14 1spps, 76 2spps, 31 3spps, 10.7ms
Speed: 1.2ms preprocess, 10.7ms inference, 1.2ms postprocess per image at shape (1, 3, 640, 224)

image 1/1 /hy-tmp/yolo_soybean_sam/dataset/outdoor/train/images/D12E-D13E_DSC04509_a_3_D13E_a_3.png: 640x160 10 1spps, 55 2spps, 26 3spps, 10.6ms
Speed: 1.0ms preprocess, 10.6ms inference, 1.2ms postprocess per image at shape (1, 3, 640, 160)

image 1/1 /hy-tmp/yolo_soybean_sam/dataset/outdoor/train/images/D12E-D13E_DSC04516_a_6_D12E_a_1.png: 640x224 12 1spps, 72 2spps, 24 3spps, 10.3ms
Speed: 1.2ms preprocess, 10.3ms inference, 1.2ms postprocess per image at shape (1, 3, 640, 224)

image 1/1 /hy-tmp/yolo_so

### Step 2: Retrain YOLO

In [1]:
import sys

# Print the current search path, then add the local ultralytics checkout at
# the end (lowest lookup priority).
print(sys.path)
ultralytics_checkout = '/hy-tmp/yolo_soybean/ultralytics'
sys.path.append(ultralytics_checkout)

['/hy-tmp/yolo_soybean_sam/ultralytics', '/usr/local/miniconda3/lib/python38.zip', '/usr/local/miniconda3/lib/python3.8', '/usr/local/miniconda3/lib/python3.8/lib-dynload', '', '/usr/local/miniconda3/lib/python3.8/site-packages']


In [1]:
from ultralytics import YOLO

# Load the "yolo_sam.pt" weights as the starting point for step 2. No model
# needs to be built from "yolov8m.yaml" first; the weights file defines it.
model = YOLO("yolo_sam.pt")

# Retrain.
# NOTE(review): this is the same normal.yaml used in step 1; to train on the
# SAM-masked images it must point at train_sam/images — confirm.
results = model.train(data="/hy-tmp/yolo_soybean/datasets/Normal_dataset/normal.yaml", epochs=300, imgsz=640, batch=16)

## Evaluation and inference

In [None]:
from ultralytics import YOLO

device = 'cuda'

# Load the "yolo_sam.pt" weights for evaluation/inference. Building a model
# from "yolov8m.yaml" first is unnecessary: it would be replaced immediately.
model = YOLO("yolo_sam.pt")
model = model.to(device)

In [None]:
from SAMHQ.segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# HQ-SAM ViT-H predictor used to mask images before the second detection pass.
# NOTE(review): this checkpoint path (/hy-tmp/yolo_soybean/...) differs from the
# one used when exporting the training images (/hy-tmp/yolo_soybean_sam/...);
# confirm both point to the same weights.
device = "cuda"
model_type = "vit_h"
sam_checkpoint = "/hy-tmp/yolo_soybean/ckpts/sam_hq_vit_h.pth"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

### Predict seeds

In [10]:
def load_labels(label_path, count_map=None):
    """Return the ground-truth bean count for one YOLO-format label file.

    Every non-blank line starts with an integer class ID. Each class ID is
    mapped to a number of beans through ``count_map``, and the values are summed.

    Args:
        label_path: Path to a YOLO ``.txt`` label file.
        count_map: Optional mapping from class ID to bean count. Defaults to
            ``{0: 1, 1: 2, 2: 3, 3: 4}``.

    Returns:
        int: Total number of beans described by the file.

    Raises:
        KeyError: If a line uses a class ID that is not in ``count_map``.
    """
    if count_map is None:
        count_map = {0: 1, 1: 2, 2: 3, 3: 4}
    total_count = 0
    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank/whitespace-only lines carry no annotation; indexing
                # fields[0] on them would raise IndexError.
                continue
            total_count += count_map[int(fields[0])]
    return total_count

def count_beans(predictions):
    """Sum the bean count of detected boxes and collect each box's centre.

    Args:
        predictions: Iterable of 6-element rows ``(x1, y1, x2, y2, conf, class_id)``.

    Returns:
        tuple: ``(total_count, centers)``. ``total_count`` sums the per-class bean
        counts ``{0: 1, 1: 2, 2: 3, 3: 4}``. ``centers`` is a list of
        ``[center_x, center_y]`` pairs in the same order as ``predictions``.
    """
    beans_per_class = {0: 1, 1: 2, 2: 3, 3: 4}  # class ID -> number of beans
    bean_total = 0
    box_centers = []
    for left, top, right, bottom, _score, cls in predictions:
        bean_total += beans_per_class[int(cls)]
        midpoint = [(left + right) / 2, (top + bottom) / 2]
        box_centers.append(midpoint)
    return bean_total, box_centers


  from .autonotebook import tqdm as notebook_tqdm


<All keys matched successfully>


In [2]:
import cv2
import PIL.Image as Image
import glob
import os
import numpy as np

# Two-stage evaluation on the eval split:
# 1. detect on the raw image;
# 2. mask it with HQ-SAM prompted at the detected box centres;
# 3. detect again on the masked image and compare the bean count with the labels.
test_images_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/eval/images'
labels_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/eval/labels'

# Sorted so the per-image printout is reproducible.
image_paths = sorted(glob.glob(os.path.join(test_images_path, '*.*')))

errors = []

for image_path in image_paths:
    image_name = os.path.basename(image_path)
    label_name = os.path.splitext(image_name)[0] + '.txt'
    label_path = os.path.join(labels_path, label_name)

    if not os.path.exists(label_path):
        print(f'Label file not found for image {image_name}, skipping...')
        continue

    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        print(f'Could not read image {image_name}, skipping...')
        continue
    image_ori = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    true_count = load_labels(label_path)

    # Stage 1: detections on the raw image provide the SAM point prompts.
    # Reset per image so a previous image's points are never reused.
    centers = []
    results = model(image_path)
    if results:
        _, centers = count_beans(results[0].boxes.data.cpu().numpy())

    # With no stage-1 detections SAM cannot be prompted (an empty point set
    # would make predictor.predict() fail), so the predicted count stays 0.
    pred_count = 0
    if centers:
        # Stage 2: keep only the SAM foreground and count on the masked image.
        predictor.set_image(image_ori)
        masks, scores, logits = predictor.predict(
            point_coords=np.array(centers),
            point_labels=np.ones(len(centers), dtype=int),  # all points are foreground
            multimask_output=True,
        )
        mask = masks[0]  # same mask choice as used to export the training images
        extracted = np.zeros_like(image_ori)
        extracted[mask] = image_ori[mask]
        extracted_image = Image.fromarray(extracted)

        results = model(extracted_image)
        if results:
            pred_count, _ = count_beans(results[0].boxes.data.cpu().numpy())

    error = abs(pred_count - true_count)
    errors.append(error)

    print(f'Image: {image_name}, True Count: {true_count}, Predicted Count: {pred_count}, Error: {error}')

if errors:
    mae = np.mean(errors)
    print(f'Mean Absolute Error (MAE): {mae}')
else:
    # np.mean([]) would return nan with a RuntimeWarning; report it explicitly.
    mae = float('nan')
    print('No labelled images were evaluated; MAE is undefined.')


### Predict pods

In [None]:
def load_labels(label_path):
    """Return the number of annotated boxes (pods) in a YOLO-format label file.

    This redefines the seed-counting ``load_labels`` from the earlier cells.
    Here every box counts as one, regardless of its class ID. Blank or
    whitespace-only lines are ignored, so they are not counted as boxes.

    Args:
        label_path: Path to a YOLO ``.txt`` label file.

    Returns:
        int: Number of non-blank annotation lines.
    """
    with open(label_path, 'r') as f:
        return sum(1 for line in f if line.strip())

def count_beans(predictions):
    """Count detected boxes (pods) and return their centre points.

    This redefines the seed-counting ``count_beans`` from the earlier cells.
    Every box counts as one; the confidence and class columns are ignored.

    Args:
        predictions: Sized iterable of 6-element rows
            ``(x1, y1, x2, y2, conf, class_id)``. In this notebook it is
            ``results[0].boxes.data`` converted to a NumPy array.

    Returns:
        tuple: ``(count, centers)``. ``count`` is ``len(predictions)``.
        ``centers`` is a list of ``[center_x, center_y]`` pairs in input order.
    """
    centers = []
    for x1, y1, x2, y2, _conf, _class_id in predictions:
        centers.append([(x1 + x2) / 2, (y1 + y2) / 2])
    return len(predictions), centers


In [None]:
import cv2
import PIL.Image as Image
import glob
import os
import numpy as np

# Two-stage evaluation of pod counts on the eval split:
# 1. detect on the raw image;
# 2. mask it with HQ-SAM prompted at the detected box centres;
# 3. detect again on the masked image and compare the box count with the labels.
test_images_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/eval/images'
labels_path = '/hy-tmp/yolo_soybean/datasets/Normal_dataset/eval/labels'

# Sorted so the per-image printout is reproducible.
image_paths = sorted(glob.glob(os.path.join(test_images_path, '*.*')))

errors = []

for image_path in image_paths:
    image_name = os.path.basename(image_path)
    label_name = os.path.splitext(image_name)[0] + '.txt'
    label_path = os.path.join(labels_path, label_name)

    if not os.path.exists(label_path):
        print(f'Label file not found for image {image_name}, skipping...')
        continue

    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        # cv2.imread returns None (instead of raising) for unreadable files.
        print(f'Could not read image {image_name}, skipping...')
        continue
    image_ori = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    true_count = load_labels(label_path)

    # Stage 1: detections on the raw image provide the SAM point prompts.
    # Reset per image so a previous image's points are never reused.
    centers = []
    results = model(image_path)
    if results:
        _, centers = count_beans(results[0].boxes.data.cpu().numpy())

    # With no stage-1 detections SAM cannot be prompted (an empty point set
    # would make predictor.predict() fail), so the predicted count stays 0.
    pred_count = 0
    if centers:
        # Stage 2: keep only the SAM foreground and count on the masked image.
        predictor.set_image(image_ori)
        masks, scores, logits = predictor.predict(
            point_coords=np.array(centers),
            point_labels=np.ones(len(centers), dtype=int),  # all points are foreground
            multimask_output=True,
        )
        mask = masks[0]  # same mask choice as used to export the training images
        extracted = np.zeros_like(image_ori)
        extracted[mask] = image_ori[mask]
        extracted_image = Image.fromarray(extracted)

        results = model(extracted_image)
        if results:
            pred_count, _ = count_beans(results[0].boxes.data.cpu().numpy())

    error = abs(pred_count - true_count)
    errors.append(error)

    print(f'Image: {image_name}, True Count: {true_count}, Predicted Count: {pred_count}, Error: {error}')

if errors:
    mae = np.mean(errors)
    print(f'Mean Absolute Error (MAE): {mae}')
else:
    # np.mean([]) would return nan with a RuntimeWarning; report it explicitly.
    mae = float('nan')
    print('No labelled images were evaluated; MAE is undefined.')
