## Dataset

In [1]:
import os
import time
import hashlib
import uuid
import requests
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from pathlib import Path
from dotenv import load_dotenv
import logging
import csv
from tqdm import tqdm

In [2]:
ROOT_PATH = Path(r"C:\prog\py\piv\pivo-segmentation")

DATA_PATH = ROOT_PATH / "data"
MODEL_PATH = ROOT_PATH / "models"

logging_format = "%(name)s - %(asctime)s - %(levelname)s - %(message)s"

load_dotenv()
logging.basicConfig(
    level=logging.WARNING,
    format=logging_format,
    datefmt="%H:%M:%S",
    # filename=ROOT_PATH / "log.txt",
    # filemode="a+",
    encoding="utf-8",
)   
logger = logging.getLogger(__name__)

In [3]:
def compute_md5(image_bytes):
    """
    Compute the MD5 hash of the given image bytes.
    This simple deduplication method can be replaced by perceptual hashing for fuzzy matching.
    """
    return hashlib.md5(image_bytes).hexdigest()

In [4]:
def download_image(img_url, output_dir, dedup_set):
    """
    Downloads an image from the given URL, performs deduplication using MD5 hash,
    and saves it with a unique name if it is not a duplicate.

    Returns:
        file_path (str): Full path to the saved image or None if duplicate/failure.
    """
    try:
        response = requests.get(img_url, timeout=10)
        if response.status_code == 200:
            image_bytes = response.content
            # Compute hash for deduplication
            hash_val = compute_md5(image_bytes)
            if hash_val in dedup_set:
                return None
            dedup_set.add(hash_val)
            # Create a unique filename using a UUID and the hash value
            unique_name = f"{uuid.uuid4().hex}_{hash_val}.png"
            file_path = os.path.join(output_dir, unique_name)
            with open(file_path, "wb") as f:
                f.write(image_bytes)
            return Path(file_path)
    except Exception as e:
        logger.error(f"Error downloading image {img_url}: {e}")
    return None

In [5]:
def yandex_images_batch_generator(prompt, output_dir, batch_size=10):
    """
    Generator that queries Yandex Images for a fixed search query, downloads images in batches of `batch_size`,
    and yields a list of file paths to the downloaded images.
    
    Deduplication is performed by computing the MD5 hash of the image bytes.
    A more robust method might use perceptual hashing (e.g., using the imagehash library).
    
    Parameters:
        output_dir (str): Directory where images will be saved.
        batch_size (int): Number of images per batch.
    
    Yields:
        batch (list): List of file paths for the downloaded images in the batch.
    """
    os.makedirs(output_dir, exist_ok=True)
    dedup_set = set()
    
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    driver = webdriver.Chrome(options=chrome_options)
    
    url = f"https://yandex.ru/images/search?text={prompt}"
    driver.get(url)
    time.sleep(3)  # Wait for the page to load
    
    batch = []
    idx = 0
    images = driver.find_elements(By.CSS_SELECTOR, "img")
    
    while True:
        # Scroll if necessary to load more images
        if idx >= len(images):
            driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
            time.sleep(2)
            images = driver.find_elements(By.CSS_SELECTOR, "img")
            if idx >= len(images):
                break  # No more images to process
        
        try:
            image_element = images[idx]
            # Get image URL. Sometimes the real URL is in the 'src' attribute, sometimes in 'data-src'
            img_url = image_element.get_attribute("src")
            if not img_url or img_url.startswith("data:"):
                idx += 1
                continue
            file_path = download_image(img_url, output_dir, dedup_set)
            if file_path:
                batch.append(file_path)
            if len(batch) == batch_size:
                yield batch
                batch = []  # Reset batch after yielding
            idx += 1
        except Exception as e:
            logger.error(f"Error processing image index {idx}: {e}")
            idx += 1

    # Yield any remaining images
    if batch:
        yield batch
    driver.quit()

## SAM

In [6]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import pandas as pd
import torch
from torch import autocast
from torch.cuda.amp import GradScaler
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from pathlib import Path
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

In [7]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

using device: cuda


In [8]:
mask_generator = SAM2AutomaticMaskGenerator.from_pretrained(
    model_id="facebook/sam2.1-hiera-large",
    points_per_side=64,
    points_per_batch=128,
    pred_iou_thresh=0.85,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.5,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=25.0,
    use_m2m=True,
)

In [9]:
def bbox_ratio(bbox):
    _, _, w, h = bbox
    return h / w

In [10]:
def filter_by_constrain(bboxes, constrain=(1.5, 5), metric=bbox_ratio):
    scalars = np.array([metric(i) for i in bboxes])
    mask1 = scalars < constrain[1]
    mask2 = scalars > constrain[0]
    return np.array(bboxes)[mask1 & mask2]

In [11]:
def suppress_inside_bboxes(bboxes, threshold=0.7):
    """
    Suppresses bounding boxes that are mostly inside larger ones.
    
    Args:
        bboxes (list or np.ndarray): List of bboxes in [x, y, w, h] format.
        threshold (float): Fraction of smaller box area that must be inside the larger one to suppress.
        
    Returns:
        list: Filtered list of bboxes.
    """
    def bbox_area(b):
        return b[2] * b[3]

    def intersection_area(b1, b2):
        x1, y1 = max(b1[0], b2[0]), max(b1[1], b2[1])
        x2 = min(b1[0] + b1[2], b2[0] + b2[2])
        y2 = min(b1[1] + b1[3], b2[1] + b2[3])
        if x2 <= x1 or y2 <= y1:
            return 0
        return (x2 - x1) * (y2 - y1)

    bboxes = np.array(list(bboxes))
    keep = [True] * len(bboxes)

    for i in range(len(bboxes)):
        for j in range(len(bboxes)):
            if i == j or not keep[i] or not keep[j]:
                continue

            area_i = bbox_area(bboxes[i])
            area_j = bbox_area(bboxes[j])
            inter = intersection_area(bboxes[i], bboxes[j])

            if area_i < area_j:
                if inter / area_i >= threshold:
                    keep[i] = False
            elif area_j < area_i:
                if inter / area_j >= threshold:
                    keep[j] = False

    return bboxes[keep]


In [19]:
def get_image_and_bboxes(image_path, mask_generator):
    try:
        image = Image.open(image_path)
        image = image.convert("RGB")
        image = np.array(image)
        raw_mask = mask_generator.generate(image)
        bbox = np.array([i["bbox"] for i in raw_mask])
        bbox = filter_by_constrain(bbox)
        bbox = suppress_inside_bboxes(bbox)
        
        return image, bbox
    except Exception as e:
        logger.error(f"Error in get_image_and_bboxes - {e}")
        return None, None
    

In [13]:
def extract_and_save_bboxes(image: np.array, bboxes: list, bbox_dir: str, csv_dir: str, image_path: str):
    """
    Extracts bounding boxes from an image, saves each cropped region as a PNG with a filename
    based on the parent image name, and creates a CSV file listing each bounding box's details.
    
    Parameters:
        image (np.array): Input image in numpy array format.
        bboxes (list): List of bounding boxes in (x, y, w, h) format.
        output_dir (str): Directory where the cropped images will be saved.
        csv_path (str): Directory where the CSV file will be saved.
        image_path (str): Path to the parent image file (used to derive the image name).
    
    CSV File:
        The CSV file will have the following columns:
        X, Y, W, H, FileName
        
        It will be saved with the same base name as the parent image, with '.csv' appended.
    
    Returns:
        None
    """
    os.makedirs(bbox_dir, exist_ok=True)
    os.makedirs(csv_dir, exist_ok=True)
    
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    
    csv_rows = []
    
    for idx, (x, y, w, h) in enumerate(bboxes.astype(int)):
        cropped = image[y:y+h, x:x+w]
        file_name = f"{base_name}_{idx}.png"
        file_path = os.path.join(bbox_dir, file_name)
        cv2.imwrite(file_path, cropped)
        csv_rows.append([x, y, w, h, file_name])
    
    csv_file_name = f"{base_name}.csv"
    csv_file_path = os.path.join(csv_dir, csv_file_name)
    
    with open(csv_file_path, mode='w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["X", "Y", "W", "H", "FileName"])
        writer.writerows(csv_rows)
    


In [20]:
def image_pipeline(image_path, mask_generator, bboxes_dir, csv_dir):
    image, bbox = get_image_and_bboxes(image_path, mask_generator)
    if image is None or len(bbox) == 0:
        return False
    extract_and_save_bboxes(image, bbox, bboxes_dir, csv_dir, image_path)
    return True


In [15]:
images_dir = Path(r"C:\prog\py\piv\dataset_raw\my_group")
bboxes_dir = Path(r"C:\prog\py\piv\dataset_raw\bboxes")
csv_dir = Path(r"C:\prog\py\piv\dataset_raw\csv")

os.makedirs(bboxes_dir, exist_ok=True)
os.makedirs(csv_dir, exist_ok=True)

In [21]:
logger.warning("Start")
image_paths = list(images_dir.iterdir())
for path in tqdm(image_paths):
    try:
        status = image_pipeline(path, mask_generator, bboxes_dir, csv_dir)
        if not status:
            logger.error 
    except Exception as e:
        logger.error(f"path:{path}, exception:{e}")

logger.warning("End")



Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  masks = self._transforms.postprocess_masks(
100%|██████████| 432/432 [3:07:25<00:00, 26.03s/it]  


In [15]:
prompts = [
    "полка с пивом в магазине",
    "полка с алкоголем в магазине",
    "полка с напитками в магазине",
]

num_images_per_prompt = [
    500,
    300,
    200,
]

batch_size = 50

downloaded_dir = DATA_PATH / "downloaded"
bboxes_dir = DATA_PATH / "bboxes"
csv_dir = DATA_PATH / "csv"

In [None]:
already_downloaded = len(list(downloaded_dir.iterdir())) if downloaded_dir.exists() else 0
logger.warning("Start")

for prompt, num_images in zip(prompts, num_images_per_prompt):
    generator = yandex_images_batch_generator(prompt, downloaded_dir, batch_size=batch_size)

    with tqdm(total=num_images, desc=prompt) as bar:
        bar.update(already_downloaded)
        while True:
            try:
                batch = next(generator)
                counter = 0
                for path in batch:
                    status = image_pipeline(path, mask_generator, bboxes_dir, csv_dir)
                    if status:
                        counter += 1
                    else:
                        try:
                            os.remove(path)
                        except:
                            pass
                        
                bar.update(counter)
                already_downloaded += len(batch)

                if already_downloaded >= num_images:
                    break
            except StopIteration:
                break
logger.warning("End")


Skipping the post-processing step due to the error above. You can still use SAM 2 and it's OK to ignore the error above, although some post-processing functionality may be limited (which doesn't affect the results in most cases; see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).
  masks = self._transforms.postprocess_masks(


In [22]:
from collections import Counter
Counter([i.suffix for i in downloaded_dir.iterdir()])

Counter({'.png': 605})

In [19]:
faulted = []
for image_path in downloaded_dir.iterdir():
    try:
        image = Image.open(image_path)
        pd.read_csv(image_path.parent.parent / "csv"/ (image_path.stem + ".csv"))
    except:
        faulted.append(image_path)
len(faulted)

0

In [25]:
num_of_bboxes = pd.DataFrame([len(pd.read_csv(i)) for i in csv_dir.iterdir()])

In [None]:
num_of_bboxes.plot.()

TypeError: PlotAccessor.hexbin() missing 2 required positional arguments: 'x' and 'y'

In [None]:
for i in faulted:
    # os.remove(i)
    pass

In [16]:
len(list(downloaded_dir.iterdir())), len(list(csv_dir.iterdir())), len(list(bboxes_dir.iterdir()))

(50, 6, 582)

In [37]:
a = """awkward
clumsy
complex
frustrating
time
-consuming
tricky
accept
adapt
be a step forward
be capable of
be frightened of
can’t take
cope with
get a grip
resist
survive
tackle
underestimate
be a waste of time
be hard to operate
drive crazy
lose patience
get on nerves
amateur
binge
-watch
DIY (do-it-yourself)
revive
preserve
adaptable
endangered
poisonous
atmosphere
exploration
investigation
satellite
species
surface
resources
creature
habitat
survivor
pond
observe
volcano
launch
to monitor
preserve
use up
fossil
head straight for
microbe
mammal
bizarre
creepy
disgusting
fabulous
impressive
irritating
satisfying
stunning
tense
uneasy
weird
to attract attention
to be extrovert
to be introvert
to be reserved
to be the life of the party
to feel left out
to interact with people
socialize
speak softly
to show off
to speak up
That’s gross!
cookware
database
intern
personal statement
stereotype
constructive
destructive
unreasonable
valid
aspect
weakness
assist
build trust
oversee
steer sb away from
assess
draw attention to
point out
think through
weigh the pros and cons
be (get) stuck with
anxiety level
breathing technique
be conscious of
be in control of
be rational
be scared to death
cure an illness
overcome my fear
panic about sth
regain control
try a therapy
be dying to
be eager to
be more than happy to
be passionate about
be prepared to
be reluctant to
be unwilling to
have no desire to
hesitate to
barrier
face time
gesture
isolation
millennial
catch sb’s attention
get hits
get publicity
have a good reputation
make an appearance
make headlines
praise sb
raise awareness
seek fame
announce
boast
confirm
deny
estimate
insist
propose
swear
Then it hit me.
trade""".replace("\n-", "-").split("\n")
b = []
for i in range(0, len(a), len(a) // 4):
    b.append(a[i:i+len(a)//4])


In [42]:
print(*b[4],sep="\n")

IndexError: list index out of range