In [2]:
import os

import numpy as np
import cv2
from tqdm import tqdm
from PIL import Image
import torch

from sam2.sam2.build_sam import build_sam2
from sam2.sam2.sam2_image_predictor import SAM2ImagePredictor

ModuleNotFoundError: No module named 'hydra'

In [4]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

In [None]:
# Single edit point for the dataset location: every directory below is derived
# from this root, so moving the dataset only requires changing one line.
DATASET_ROOT = "/home/aerozrenie/second_stage/filtered/bbcrops_resized"

# Input images, their YOLO bounding-box labels (one .txt per image), and the
# directory that receives the generated polygon (segmentation) labels.
IMAGES_DIR = f"{DATASET_ROOT}/images"
LABELS_BB_DIR = f"{DATASET_ROOT}/labels"
OUTPUT_MASKS_DIR = f"{DATASET_ROOT}/labels_masks"

# Create the output directory up front; exist_ok makes re-running the cell safe.
os.makedirs(OUTPUT_MASKS_DIR, exist_ok=True)

In [None]:
# Model config and the matching checkpoint file (SAM 2.1, "hiera large"
# according to the file names).
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
CHECKPOINT_PATH = (
    "/home/aerozrenie/second_stage/segment-anything-2/checkpoints/"
    "sam2.1_hiera_large.pt"
)

# Build the model on the selected device and wrap it in the image predictor
# that the mask-generation loop uses.
sam2_model = build_sam2(
    MODEL_CFG,
    CHECKPOINT_PATH,
    device=device,
)
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
def _read_yolo_boxes(label_path, image_width, image_height):
    """Parse a YOLO bounding-box label file into pixel-space box prompts.

    Every usable line has the form ``class x_center y_center width height``
    where the four geometry values are normalized by the image size. Lines
    with fewer than five fields (blank lines included) are skipped.

    Args:
        label_path: Path of the ``.txt`` label file.
        image_width: Image width in pixels.
        image_height: Image height in pixels.

    Returns:
        A list of dicts with keys ``'box'`` (NumPy array
        ``[x_min, y_min, x_max, y_max]`` in pixels) and ``'class_id'`` (int).

    Raises:
        ValueError: If a field of a five-field line is not numeric.
    """
    boxes = []
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            class_id = int(parts[0])
            x_center_norm, y_center_norm, width_norm, height_norm = map(float, parts[1:5])

            # Normalized center/size -> pixel-space corner coordinates.
            x_center_px = x_center_norm * image_width
            y_center_px = y_center_norm * image_height
            half_width_px = width_norm * image_width / 2
            half_height_px = height_norm * image_height / 2

            box = np.array([
                x_center_px - half_width_px,
                y_center_px - half_height_px,
                x_center_px + half_width_px,
                y_center_px + half_height_px,
            ])
            boxes.append({'box': box, 'class_id': class_id})
    return boxes


def _mask_to_yolo_lines(mask, class_id, image_width, image_height):
    """Convert a binary mask into YOLO polygon label lines.

    Args:
        mask: 2-D ``uint8`` array, non-zero inside the object.
        class_id: Class index written at the start of every line.
        image_width: Image width in pixels, used to normalize x.
        image_height: Image height in pixels, used to normalize y.

    Returns:
        A list of strings ``"<class_id> x1 y1 x2 y2 ..."`` with coordinates
        normalized to ``[0, 1]``, one per external contour with at least
        three points.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    scale = np.array([image_width, image_height], dtype=np.float32)

    lines = []
    # NOTE(review): every external contour becomes its own label line, so a
    # mask made of several disconnected blobs yields several lines for a
    # single box prompt -- confirm this is the intended label layout.
    for contour in contours:
        # A polygon needs at least three vertices.
        if contour.shape[0] < 3:
            continue
        # findContours returns (N, 1, 2) integer points; flatten to (N, 2),
        # normalize x by width and y by height, and clamp into [0, 1].
        polygon = np.clip(contour.reshape(-1, 2).astype(np.float32) / scale, 0.0, 1.0)
        coords = polygon.reshape(-1).tolist()
        lines.append(f"{class_id} " + " ".join(map(str, coords)))
    return lines


# The name is lower-cased before matching, so lower-case suffixes are enough.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# Sorted so the processing order is reproducible (os.listdir order is arbitrary).
image_filenames = sorted(
    f for f in os.listdir(IMAGES_DIR) if f.lower().endswith(IMAGE_EXTENSIONS)
)

# bfloat16 autocast is only wanted on CUDA. Passing enabled=False on CPU avoids
# the "target dtype is not supported" warning that a float32 CPU autocast emits.
use_autocast = device.type == "cuda"

for image_filename in tqdm(image_filenames, desc="Обработка изображений"):

    base_filename = os.path.splitext(image_filename)[0]
    image_path = os.path.join(IMAGES_DIR, image_filename)
    label_bb_path = os.path.join(LABELS_BB_DIR, base_filename + ".txt")
    output_mask_path = os.path.join(OUTPUT_MASKS_DIR, base_filename + ".txt")

    # Images without a bounding-box label file are skipped silently.
    if not os.path.exists(label_bb_path):
        continue

    # Context manager closes the underlying file as soon as the pixels are read.
    with Image.open(image_path) as image_pil:
        image_np = np.array(image_pil.convert("RGB"))
    image_height, image_width = image_np.shape[:2]

    input_boxes_with_class_ids = _read_yolo_boxes(label_bb_path, image_width, image_height)
    if not input_boxes_with_class_ids:
        continue

    all_output_lines = []

    # set_image runs inside the same inference/autocast context as predict,
    # matching the usage shown in the SAM 2 README.
    with torch.inference_mode(), torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_autocast):
        predictor.set_image(image_np)

        for item in input_boxes_with_class_ids:
            masks_pred, _, _ = predictor.predict(
                point_coords=None,
                point_labels=None,
                box=item['box'][None, :],
                multimask_output=False,
            )

            # One mask is requested per box; binarize it for OpenCV.
            mask_np_binary = (np.squeeze(masks_pred[0]) > 0.5).astype(np.uint8)

            all_output_lines.extend(
                _mask_to_yolo_lines(mask_np_binary, item['class_id'], image_width, image_height)
            )

    if all_output_lines:
        with open(output_mask_path, 'w') as f_out:
            f_out.writelines(line + "\n" for line in all_output_lines)
    else:
        # tqdm.write prints the message without corrupting the progress bar.
        tqdm.write(f"Не найдено валидных сегментов для {image_filename}")

print("Обработка завершена.")