Skip to content

Commit

Permalink
perf(auto-labeling): Enhance expandability (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
QIN2DIM committed Oct 25, 2023
1 parent 94646fc commit 6aa00b2
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 116 deletions.
2 changes: 1 addition & 1 deletion automation/assets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def merge(self, fd: Path, td: Path):


def run():
sources = "https://github.com/QIN2DIM/hcaptcha-challenger/issues/864"
sources = "https://github.com/QIN2DIM/hcaptcha-challenger/issues/860"
am = AssetsManager.from_sources(sources)
am.execute()

Expand Down
208 changes: 96 additions & 112 deletions automation/auto_labeling.py
Original file line number Diff line number Diff line change
@@ -1,176 +1,160 @@
# -*- coding: utf-8 -*-
# Time : 2023/10/20 17:28
# Time : 2023/10/24 5:39
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description: zero-shot image classification
# Description:
from __future__ import annotations

import logging
import os
import shutil
import sys
from dataclasses import dataclass
from dataclasses import field
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import List, Tuple
from typing import Tuple, List

from PIL import Image
from hcaptcha_challenger import split_prompt_message, label_cleaning
from hcaptcha_challenger import (
DataLake,
install,
ModelHub,
ZeroShotImageClassifier,
register_pipline,
)
from tqdm import tqdm

project_dir = Path(__file__).parent.parent
db_dir = project_dir.joinpath("database2309")
from flow_card import datalake_card

logging.basicConfig(
level=logging.INFO, stream=sys.stdout, format="%(asctime)s - %(levelname)s - %(message)s"
)

install(upgrade=True)


@dataclass
class AutoLabeling:
positive_labels: List[str] = field(default_factory=list)
candidate_labels: List[str] = field(default_factory=list)
images_dir: Path = field(default=Path)
pending_tasks: List[Path] = field(default_factory=list)
"""
Example:
---
checkpoint = "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
# checkpoint = "QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336"
1. Roughly observe the distribution of the dataset and design a DataLake for the challenge prompt.
- ChallengePrompt: "Please click each image containing an off-road vehicle"
- positive_labels --> ["off-road vehicle"]
- negative_labels --> ["bicycle", "car"]
output_dir: Path = None
2. You can design them in batches and save them as YAML files,
which the classifier can read to create the DataLake automatically.
def load_zero_shot_model(self):
import torch
from transformers import pipeline
3. Note that positive_labels is a list, so you can specify multiple labels for this variable
if the label that the prompt refers to is ambiguous.
device = "cuda" if torch.cuda.is_available() else "cpu"
task = "zero-shot-image-classification"
"""

detector = pipeline(task=task, model=self.checkpoint, device=device, batch_size=8)
input_dir: Path = field(default_factory=Path)
pending_tasks: List[Path] = field(default_factory=list)
tool: ZeroShotImageClassifier = field(default_factory=ZeroShotImageClassifier)

return detector
output_dir: Path = field(default_factory=Path)

@classmethod
def from_prompt(cls, positive_labels: List[str], candidate_labels: List[str], images_dir: Path):
images_dir.mkdir(parents=True, exist_ok=True)
limit: int = field(default=1)
"""
By default, all pictures in the specified folder are classified and moved.
Set `limit` to cap the number of images processed in one run.
"""

pending_tasks: List[Path] = []
for image_name in os.listdir(images_dir):
image_path = images_dir.joinpath(image_name)
@classmethod
def from_datalake(cls, dl: DataLake, **kwargs):
if not isinstance(dl.joined_dirs, Path):
raise TypeError(
f"The dataset joined_dirs needs to be passed in for auto-labeling. - {dl.joined_dirs=}"
)
if not dl.joined_dirs.exists():
raise ValueError(f"Specified dataset path does not exist - {dl.joined_dirs=}")

input_dir = dl.joined_dirs
pending_tasks = []
for image_name in os.listdir(input_dir):
image_path = input_dir.joinpath(image_name)
if image_path.is_file():
pending_tasks.append(image_path)

return cls(
positive_labels=positive_labels,
candidate_labels=candidate_labels,
images_dir=images_dir,
pending_tasks=pending_tasks,
)

def valid(self):
if not self.pending_tasks:
print("No pending tasks")
return
if len(self.candidate_labels) <= 2:
print(f">> Please enter at least three class names - {self.candidate_labels=}")
return
if (limit := kwargs.get("limit")) is None:
limit = len(pending_tasks)
elif not isinstance(limit, int) or limit < 1:
raise ValueError(f"limit should be a positive integer greater than zero. - {limit=}")

return True
tool = ZeroShotImageClassifier.from_datalake(dl)
return cls(tool=tool, input_dir=input_dir, pending_tasks=pending_tasks, limit=limit)

def mkdir(self) -> Tuple[Path, Path]:
__formats = ("%Y-%m-%d %H:%M:%S.%f", "%Y%m%d%H%M")
now = datetime.strptime(str(datetime.now()), __formats[0]).strftime(__formats[1])
yes_dir = self.images_dir.joinpath(now, "yes")
bad_dir = self.images_dir.joinpath(now, "bad")
yes_dir = self.input_dir.joinpath(now, "yes")
bad_dir = self.input_dir.joinpath(now, "bad")
yes_dir.mkdir(parents=True, exist_ok=True)
bad_dir.mkdir(parents=True, exist_ok=True)

self.output_dir = yes_dir.parent

return yes_dir, bad_dir

def execute(self, limit: int | str = None):
if not self.valid():
def execute(self, model):
if not self.pending_tasks:
logging.info("No pending tasks")
return

# Format datafolder
yes_dir, bad_dir = self.mkdir()

# Load zero-shot model
detector = self.load_zero_shot_model()

desc_in = f'"{self.input_dir.parent.name}/{self.input_dir.name}"'
total = len(self.pending_tasks)
desc_in = f'"{self.checkpoint}/{self.images_dir.name}"'
if isinstance(limit, str) and limit == "all":
limit = total
else:
limit = limit or total

logging.info(f"load {self.tool.positive_labels=}")
logging.info(f"load {self.tool.candidate_labels=}")

with tqdm(total=total, desc=f"Labeling | {desc_in}") as progress:
for image_path in self.pending_tasks[:limit]:
for image_path in self.pending_tasks[: self.limit]:
# The label at position 0 is the highest scoring target
image = Image.open(image_path)
results = self.tool(model, image)

# Binary Image classification
predictions = detector(image, candidate_labels=self.candidate_labels)

# Move positive cases to yes/
# Move negative cases to bad/
if predictions[0]["label"] in self.positive_labels:
# we're only dealing with binary classification tasks here
if results[0]["label"] in self.tool.positive_labels:
output_path = yes_dir.joinpath(image_path.name)
else:
output_path = bad_dir.joinpath(image_path.name)

shutil.move(image_path, output_path)

progress.update(1)


@dataclass
class DataGroup:
positive_labels: List[str] | str
joined_dirs: List[str]
negative_labels: List[str]

def __post_init__(self):
if isinstance(self.positive_labels, str):
self.positive_labels = [self.positive_labels]

@property
def input_dir(self):
return db_dir.joinpath(*self.joined_dirs).absolute()

def auto_labeling(self, **kwargs):
pls = []
for pl in self.positive_labels:
pl = pl.replace("_", " ")
pl = split_prompt_message(label_cleaning(pl), "en")
pls.append(pl)

candidate_labels = pls.copy()
def run():
modelhub = ModelHub.from_github_repo()
modelhub.parse_objects()

if isinstance(self.negative_labels, list) and len(self.negative_labels) != 0:
candidate_labels.extend(self.negative_labels)
model = register_pipline(modelhub)

al = AutoLabeling.from_prompt(pls, candidate_labels, self.input_dir)
al.execute(limit=kwargs.get("limit"))
images_dir = Path(__file__).parent.parent.joinpath("database2309")

return al


def edit_in_the_common_cases():
# prompt to negative labels
# input_dir = /[Project_dir]/database2309/*[joined_dirs]

dg = DataGroup(
positive_labels=["helicopter", "excavator"],
joined_dirs=["motorized_machine"],
negative_labels=["laptop", "chess", "plant", "natural landscape", "mountain"],
)

dg = DataGroup(
positive_labels=["off road vehicle"],
joined_dirs=["off_road_vehicle"],
negative_labels=["bicycle", "car"],
)

nox = dg.auto_labeling(limit="all")
if "win32" in sys.platform and nox.output_dir:
os.startfile(nox.output_dir)
for card in datalake_card:
# Keep only the task cards we care about; skip all others
if "furniture" not in card["joined_dirs"]:
continue
# Generating a dataclass from serialized data
dl = DataLake(
positive_labels=card["positive_labels"],
negative_labels=card["negative_labels"],
joined_dirs=images_dir.joinpath(*card["joined_dirs"]),
)
# Starts an automatic labeling task
al = AutoLabeling.from_datalake(dl)
al.execute(model)
# Automatically open output directory
if "win32" in sys.platform and al.output_dir.is_dir():
os.startfile(al.output_dir)


if __name__ == "__main__":
edit_in_the_common_cases()
run()
2 changes: 1 addition & 1 deletion automation/continue_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@ def run(prompt: str, model_name: str | None = None):


if __name__ == "__main__":
run("plant", "nested_plant")
run("off_road_vehicle")
2 changes: 1 addition & 1 deletion automation/datasets_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
collected = []
per_times = 60
tmp_dir = Path(__file__).parent.joinpath("tmp_dir")
sitekey = "f5561ba9-8f1e-40ca-9b5b-a0b3f719ef34"
sitekey = SiteKey.user_easy


async def collete_datasets(context: ASyncContext):
Expand Down
19 changes: 19 additions & 0 deletions automation/flow_card.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
# Time : 2023/10/26 2:58
# Author : QIN2DIM
# GitHub : https://github.com/QIN2DIM
# Description:
# Run `assets_manager.py` to get test data from GitHub issues

# Task cards for the auto-labeling workflow.
#
# Each card is a plain dict with three keys, all consumed by
# `auto_labeling.py` when it builds a DataLake:
#   - "positive_labels": labels whose top-scoring match sends an image to `yes/`
#   - "negative_labels": the remaining candidate labels
#   - "joined_dirs":     path components joined onto the dataset root directory
flow_card = [
    {
        "positive_labels": ["off-road vehicle"],
        "negative_labels": ["car", "bicycle"],
        "joined_dirs": ["off_road_vehicle"],
    },
    {
        "positive_labels": ["furniture", "chair"],
        "negative_labels": ["guitar", "keyboard", "game tool", "headphones"],
        "joined_dirs": ["furniture"],
    },
]

# `auto_labeling.py` imports this list as `datalake_card`; expose that name as
# an alias (same list object) so the import no longer fails with ImportError.
datalake_card = flow_card
5 changes: 4 additions & 1 deletion automation/mini_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def upgrade_objects(aid_):
# "vineyard": "vineyard2309",
# "pair_of_roller_skates": "pair_of_roller_skates2310",
# "nested_plant": "nested_plant2311",
"off_road_vehicle": "off_road_vehicle2309"
"off_road_vehicle": "off_road_vehicle2310",
# "nested_largest_turtle": "nested_largest_turtle2309",
# "pair_of_headphones": "pair_of_headphones2309",
# "furniture": "furniture2310",
}
# fmt:on

Expand Down

0 comments on commit 6aa00b2

Please sign in to comment.