In [1]:
from PIL import Image
from ultralytics import YOLO
import os, yaml, json, requests
import matplotlib.pyplot as plt
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction, predict
import tqdm, random
import time
import torch
import torchsummary

## SAHI: sliced inference on a single example image

In [2]:
%matplotlib inline
def show_single_img(img: Image.Image, title=None):
    """Display one image on a fresh matplotlib figure with the axes hidden.

    Args:
        img: PIL image to display.
        title: Optional title drawn above the image; nothing is drawn when ``None``.
    """
    # Use an explicit figure/axes pair instead of clearing whatever axes the
    # pyplot state machine currently points at.
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis("off")
    if title is not None:
        ax.set_title(title)
    plt.show()

In [None]:
img_path = "./example.jpg"
# Image.open is lazy and keeps the file open until the data is read, so load it
# eagerly via convert() inside a context manager that closes the handle.
# Converting to RGB also guards against grayscale/RGBA files, which would
# otherwise reach the detector with a different channel layout.
with Image.open(img_path) as raw_img:
    img = raw_img.convert("RGB")
show_single_img(img)

In [6]:
# Slicing settings for this single-image demo: square 640-px tiles with 20 % overlap.
SLICE_SIZE = 640
OVERLAP_RATIO = 0.2

detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov8",
    model_path="./models/best.pt",
    confidence_threshold=0.3,
    device="cpu",
)
result = get_sliced_prediction(
    img,
    detection_model,
    slice_height=SLICE_SIZE,
    slice_width=SLICE_SIZE,
    overlap_height_ratio=OVERLAP_RATIO,
    overlap_width_ratio=OVERLAP_RATIO,
    verbose=0,
)

In [None]:
# Number of objects returned by the sliced prediction on the example image.
res_list = result.object_prediction_list
num_objects = len(res_list)
print(num_objects)

In [8]:
# Save SAHI's rendering of the predictions into the current directory
# (the output file name is chosen by SAHI).
result.export_visuals(export_dir="./")

## Prediction: predicted vs. labelled object counts over the dataset

In [6]:
class Predictor:
    """Count objects with SAHI sliced inference and compare them with YOLO label files.

    Args:
        model_path: Path to a YOLOv8 checkpoint, loaded through SAHI.
        conf: Confidence threshold passed to SAHI's detection model.
        device: Device string passed to ``AutoDetectionModel.from_pretrained``.
        img_path: Directory holding the evaluation images.
        lab_path: Directory holding one ``<image stem>.txt`` label file per image.
        slice_size: Height and width (pixels) of each SAHI slice.
        overlap_ratio: Overlap ratio between neighbouring slices, used on both axes.
    """

    def __init__(self, model_path: str, conf=0.5, device: str = "cpu",
                 img_path: str = "./dataset/images/",
                 lab_path: str = "./dataset/labels/",
                 slice_size: int = 384, overlap_ratio: float = 0.2):
        self.detection_model = AutoDetectionModel.from_pretrained(
            model_type='yolov8',
            model_path=model_path,
            confidence_threshold=conf,
            device=device,
        )
        # File name without directory or extension ("./models/best.pt" -> "best");
        # used to name the results file. Works for any suffix, not just ".pt".
        self.model_name = os.path.splitext(os.path.basename(model_path))[0]
        self.img_path = img_path
        self.lab_path = lab_path
        self.slice_size = slice_size
        self.overlap_ratio = overlap_ratio

    def get_batches(self, num=10, seed=None):
        """Yield ``num`` disjoint, randomly composed batches that together cover every image.

        Images are split as evenly as possible. When the image count is not divisible
        by ``num``, some batches hold one extra image instead of the remainder being dropped.

        Args:
            num: Number of batches to yield; must be >= 1.
            seed: Optional seed for a private RNG. ``None`` shuffles with the global
                ``random`` state.

        Raises:
            ValueError: On first iteration if ``num`` < 1.
        """
        if num < 1:
            raise ValueError(f"num must be >= 1, got {num}")
        # Sort first so that a given seed yields the same split regardless of listdir order.
        img_all = sorted(os.listdir(self.img_path))
        n_images = len(img_all)
        print("num_images=", n_images)
        print("batch_size=", n_images // num)
        print("batch_num=", num)
        if seed is None:
            random.shuffle(img_all)
        else:
            random.Random(seed).shuffle(img_all)
        for i in range(num):
            # Proportional boundaries: batch sizes differ by at most one and the
            # last boundary is exactly n_images, so no image is left out.
            yield img_all[i * n_images // num:(i + 1) * n_images // num]

    def _label_path(self, img_name):
        """Return the label path for ``img_name``: same stem, ``.txt`` extension, in ``lab_path``."""
        stem = os.path.splitext(img_name)[0]
        return os.path.join(self.lab_path, stem + ".txt")

    def _count_labels(self, img_name):
        """Return the number of non-blank lines (one per object) in the image's label file.

        A missing label file counts as 0 objects, matching the YOLO convention that
        images without objects need no label file.
        """
        lab_path_full = self._label_path(img_name)
        if not os.path.exists(lab_path_full):
            return 0
        with open(lab_path_full, "r") as f:
            return sum(1 for line in f if line.strip())

    def predict_batch(self, img_list):
        """Return ``(real, pred)`` object counts summed over the images in ``img_list``.

        ``real`` is read from the label files; ``pred`` is the number of objects SAHI
        sliced prediction returns for each image.
        """
        real = 0
        pred = 0
        for img_name in img_list:
            img_path_full = os.path.join(self.img_path, img_name)
            pred += len(self.predict_single(img_path_full).object_prediction_list)
            real += self._count_labels(img_name)
        return real, pred

    def predict_single(self, img_path_full):
        """Run SAHI sliced prediction on one image path and return SAHI's result object."""
        return get_sliced_prediction(
            img_path_full,
            self.detection_model,
            slice_height=self.slice_size,
            slice_width=self.slice_size,
            overlap_height_ratio=self.overlap_ratio,
            overlap_width_ratio=self.overlap_ratio,
            verbose=0)

    def sahi(self, dates=None, raw_path="./raw/", images_per_plant=5):
        """Count predicted objects per image for each date folder under ``raw_path``.

        The images in ``raw_path/<date>/`` are taken in sorted file-name order and split
        into consecutive chunks of ``images_per_plant``; each chunk is one plant.
        NOTE(review): assumes a plant's images are adjacent in sorted name order; confirm.

        Returns:
            One list per date. Each list holds dicts with ``plant_idx``, ``leaf_idx`` and
            ``count`` (the per-image prediction counts for that plant).
        """
        if dates is None:
            dates = []
        results = []
        for data in dates:
            path_full = os.path.join(raw_path, data)
            # os.listdir order is arbitrary; sort so the grouping into plants is deterministic.
            img_list = sorted(os.listdir(path_full))
            result_single_day = []
            qbar = tqdm.tqdm(range(0, len(img_list), images_per_plant))
            for idx in qbar:
                qbar.set_description(f"date={data}, idx={idx}")
                img_list_batch = img_list[idx:idx + images_per_plant]
                count_single_plant = [
                    len(self.predict_single(os.path.join(path_full, img_name)).object_prediction_list)
                    for img_name in img_list_batch
                ]
                result_single_day.append({
                    "plant_idx": idx // images_per_plant,
                    # NOTE(review): idx advances in steps of images_per_plant, so this is
                    # always 0. Kept so the output schema stays unchanged.
                    "leaf_idx": idx % images_per_plant,
                    "count": count_single_plant,
                })
            results.append(result_single_day)
        return results

    def work(self, num_batches=100):
        """Evaluate the model over ``num_batches`` random batches of the dataset.

        Stores the per-batch counts in ``self.results`` as ``{'real': [...], 'pred': [...]}``
        and the elapsed time in seconds in ``self.time``.
        """
        start_time = time.perf_counter()
        batches = list(self.get_batches(num=num_batches))
        real_list = []
        pred_list = []
        real_total = 0
        pred_total = 0
        qbar = tqdm.tqdm(batches)
        for batch in qbar:
            real, pred = self.predict_batch(batch)
            real_list.append(real)
            pred_list.append(pred)
            real_total += real
            pred_total += pred
            # Update after the batch so the bar shows totals that include it.
            qbar.set_description(f"real={real_total}, pred={pred_total}")
        self.results = {'real': real_list,
                        'pred': pred_list}
        self.time = time.perf_counter() - start_time
        print(f"Done in {self.time:.2f} sec!")

    def save(self, results_dir="./results/"):
        """Write ``self.results`` as JSON to ``<results_dir>/results_<model_name>.json``.

        Raises:
            RuntimeError: If ``work()`` has not produced any results yet.
        """
        results = getattr(self, "results", None)
        if results is None:
            raise RuntimeError("No results to save; run work() first.")
        # makedirs creates missing parents and tolerates an existing directory.
        os.makedirs(results_dir, exist_ok=True)
        results_file_name = os.path.join(results_dir, f"results_{self.model_name}.json")
        with open(results_file_name, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=4, ensure_ascii=False)
        print(f"Results saved to {results_file_name}")
            

In [None]:
# Evaluate the trained checkpoint over the whole dataset in random batches.
predictor = Predictor(model_path="./models/best.pt", conf=0.5)
predictor.work()

In [None]:
# Persist the per-batch real/pred counts collected by work() as JSON.
predictor.save()

## Params: model size (module and parameter counts)

In [10]:
from ultralytics.utils.torch_utils import *

In [None]:
# Explicit import so readers can see where these helpers come from.
from ultralytics.utils.torch_utils import get_num_gradients, get_num_params

model = YOLO('./models/train4_best.pt')

n_p = get_num_params(model)
# Presumably the number of trainable (requires_grad) parameters; see ultralytics torch_utils.
n_g = get_num_gradients(model)
# Counts every module yielded by modules() (containers and the top-level module
# included), not only compute layers.
n_l = len(list(model.modules()))

print(f"{n_l} layers, {n_p} parameters, {n_g} gradients")