# Модель YOLOandMobileSAM

In [None]:
import os
import torch
import cv2
import matplotlib.pyplot as plt
import numpy as np
from helper import get_prompt, get_mask_image, show_masks, show_boxes, overlay_masks_on_black_background
from mobile_sam import sam_model_registry, SamPredictor
from ultralytics import YOLO
import count_area

In [4]:
class YOLOandMobileSAM:
    """Two-stage segmentation pipeline: YOLO proposes boxes, MobileSAM segments them.

    The YOLO detections are turned into box prompts by ``helper.get_prompt``
    and passed to a MobileSAM ``SamPredictor``, which returns one mask per box
    (``multimask_output=False``).
    """

    # File extensions that ``process_folder`` treats as images; any other
    # directory entry (label files, OS metadata, subfolders) is skipped.
    # Override on a subclass/instance to accept additional formats.
    IMAGE_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff', '.webp')

    def __init__(self, yolo_weights='./weights/best_nano_val50.pt', sam_weights='./weights/mobile_sam.pt'):
        """Load both models.

        Args:
            yolo_weights: Path to the trained YOLO checkpoint.
            sam_weights: Path to the MobileSAM checkpoint.
        """
        self.yolo_model = self._load_yolo(yolo_weights)
        self.sam_predictor = self._load_mobile_sam(sam_weights)

    def _load_yolo(self, weights_path):
        """Load a pretrained YOLO model from ``weights_path``."""
        return YOLO(weights_path)

    def _load_mobile_sam(self, sam_checkpoint):
        """Build the MobileSAM ("vit_t") model and wrap it in a ``SamPredictor``.

        The model is moved to CUDA when available (CPU otherwise) and put in
        eval mode.
        """
        model_type = "vit_t"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        sam.eval()
        return SamPredictor(sam)

    def predict_yolo(self, img_path, conf=0.5):
        """Run YOLO on ``img_path`` (640 px, confidence threshold ``conf``).

        Returns the YOLO model object itself, not the list of results from
        ``predict``.

        NOTE(review): ``get_prompt`` receives this model object, so it
        presumably reads the detections from the model's last-prediction
        state. Confirm in helper.py before changing the return value.
        """
        self.yolo_model.predict(img_path, imgsz=640, conf=conf, save_conf=True)
        return self.yolo_model

    def run_predict(self, image_path, show=False, save_path=None):
        """Run the full pipeline (YOLO boxes -> MobileSAM masks) on one image.

        Args:
            image_path: Path to the input image.
            show: If True, display the image with masks and boxes via matplotlib.
            save_path: If truthy, forwarded to ``overlay_masks_on_black_background``
                together with the masks.

        Returns:
            Tuple ``(masks, iou_predictions, low_res_masks)`` exactly as returned
            by ``SamPredictor.predict_torch``.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.sam_predictor.set_image(image)

        yolo_result = self.predict_yolo(image_path)
        boxes = get_prompt(yolo_result)

        # The prompt tensor must live on the same device as the SAM model;
        # a hard-coded CPU tensor breaks predict_torch when SAM runs on CUDA.
        input_boxes = torch.as_tensor(boxes, device=self.sam_predictor.device)
        transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(
            input_boxes, image.shape[:2]
        )

        masks, iou_predictions, low_res_masks = self.sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False
        )

        if show:
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            show_masks(masks, plt.gca())
            show_boxes(boxes, plt.gca())
            plt.axis('off')
            plt.show()

        if save_path:
            overlay_masks_on_black_background(masks, save_path)

        return masks, iou_predictions, low_res_masks

    @classmethod
    def _list_image_files(cls, folder):
        """Return the sorted names of regular files in ``folder`` with an image extension."""
        return sorted(
            name for name in os.listdir(folder)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        )

    def process_folder(self, images_folder, labels_folder, results_folder):
        """Run the pipeline sequentially on every image in ``images_folder``.

        For each image, the mask overlay is requested at
        ``results_folder/res.jpg`` (the same file for every image). That path
        is then handed to ``count_area.calculate_area_and_perimeter`` together
        with the image file name.

        Args:
            images_folder: Directory with input images. Entries that are not
                regular files with an extension in ``IMAGE_EXTENSIONS`` are skipped.
            labels_folder: Currently unused; kept for interface compatibility.
            results_folder: Directory for the intermediate mask image. It is
                created if missing.
        """
        os.makedirs(results_folder, exist_ok=True)
        res_mask_path = os.path.join(results_folder, 'res.jpg')

        for image_file in self._list_image_files(images_folder):
            image_path = os.path.join(images_folder, image_file)
            print(f"Processing {image_file}...")

            # Pass the real output path rather than a bare True flag.
            # NOTE(review): assumes overlay_masks_on_black_background writes to
            # its second argument; confirm in helper.py.
            self.run_predict(image_path=image_path, show=False, save_path=res_mask_path)

            count_area.calculate_area_and_perimeter(image_file, res_mask_path)

In [None]:
model = YOLOandMobileSAM()

# Validation split of the dataset; adjust this base path for your machine.
dataset_dir = 'C:/Users/z.kate/source/GitHubRepos/ELC_Fall_git/YOLO_SAM/datasets/valid_all/'
images_folder = dataset_dir + 'images/'
labels_folder = dataset_dir + 'labels/'
results_folder = dataset_dir + 'res_masks/'

# Uncomment to process every image in images_folder with the pipeline above.
# model.process_folder(images_folder, labels_folder, results_folder)