## ResNet of roboflow

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/captcha-challenger/hcaptcha-model-factory/blob/main/automation/roboflow_resnet.ipynb)

In [None]:
# Print the NVIDIA GPU status of this runtime (sanity check before training).
!nvidia-smi

## Exporting the configuration from [`zip_dataset.py`](https://github.com/CaptchaAgent/hcaptcha-model-factory/blob/main/automation/zip_dataset.py)

In [None]:
# GitHub token; it is exported as the GITHUB_TOKEN environment variable further
# down. When it is empty, the deployment and rolling-upgrade cells are skipped.
GITHUB_TOKEN = ""
# Task identifier: the dataset is read from `<task_name>.zip` and the trained
# model is written to `model/<task_name>/<task_name>.onnx`.
task_name = "robot"
# File name under which the model is uploaded (a `.onnx` suffix is stripped later).
onnx_archive_name = "robot2312"

# If you are training a nested model, set this to the prompt the model is
# bound to; it is normalised with `handle()` during the rolling upgrade.
# Otherwise, leave it empty.
NESTED_PROMPT = ""

## No need to change the code below

In [None]:
"""
Injecting Mystical Power into AutoDL
"""
from pathlib import Path
import os
import subprocess

# The platform is treated as AutoDL only when both marker directories exist;
# the flag stays None everywhere else.
IS_AUTODL_PLATFORM = None
if all(Path(marker).exists() for marker in ("/root/autodl-pub", "/root/autodl-tmp")):
    IS_AUTODL_PLATFORM = True
    # Source /etc/network_turbo in a throwaway bash and capture the proxy
    # variables it leaves in the environment (presumably AutoDL's network
    # acceleration settings - confirm against the platform docs).
    proxy_env = subprocess.run(
        'bash -c "source /etc/network_turbo && env | grep proxy"',
        shell=True,
        capture_output=True,
        text=True,
    ).stdout
    # Re-export every NAME=VALUE pair into this process; lines without "=" are ignored.
    for entry in proxy_env.splitlines():
        name, separator, setting = entry.partition("=")
        if separator:
            os.environ[name] = setting

In [None]:
# Install the Python packages used by the following cells.
!pip install loguru onnx fire hcaptcha-challenger
# Fetch the model factory, keep only its `src/` tree in the working directory
# and remove the rest of the checkout.
!git clone https://github.com/CaptchaAgent/hcaptcha-model-factory.git
!mv -f hcaptcha-model-factory/src src
!rm -rf hcaptcha-model-factory/

In [None]:
import os

import hcaptcha_challenger as solver

# Expose the token to the later cells, which read it back through os.getenv().
os.environ["GITHUB_TOKEN"] = GITHUB_TOKEN
# Tolerate an archive name that was entered together with its `.onnx` extension.
onnx_archive_name = onnx_archive_name.replace(".onnx", "")

# NOTE(review): presumably reports/validates the task name - confirm against hcaptcha_challenger.
solver.diagnose_task(task_name)

Upload the zipped dataset (`<task_name>.zip`) to `[PROJECT_DIR]/`

In [None]:
from pathlib import Path
from loguru import logger
import zipfile

# Every working directory is resolved relative to the current directory.
this_dir = Path(os.path.abspath("."))
project_dir = this_dir

model_dir = project_dir.joinpath("model")
factory_data_dir = project_dir.joinpath("data")
source_dir = project_dir.joinpath("src")
zip_path = project_dir.joinpath(f"{task_name}.zip")

# Fail with an actionable message (which file, where) rather than a bare exception.
if not zip_path.exists():
    raise FileNotFoundError(
        f"Dataset archive not found, please upload it first - path={zip_path}"
    )
# Unpack into data/<task_name>/ below the directory later handed to the trainer.
with zipfile.ZipFile(zip_path) as z:
    z.extractall(factory_data_dir.joinpath(task_name))

Training

In [None]:
# Switch into `src/` (presumably so that the `factories` package imported
# below can be resolved - confirm).
%cd {source_dir}

from factories.resnet import ResNet

# - INPUT: `[PROJECT]/data/<task_name>`
# - OUTPUT: `[PROJECT]/model/<task_name>/<task_name>.onnx`

model = ResNet(
    task_name=task_name,
    epochs=None,  # default to 200
    batch_size=None,  # default to 4
    dir_dataset=str(factory_data_dir),
    dir_model=str(model_dir),
)
# Train, then (judging by the method name) convert the .pth checkpoint to ONNX.
model.train()
model.conv_pth2onnx(verbose=False)

Deploy model to GitHub

In [None]:
import locale

# Make locale.getpreferredencoding() always answer "UTF-8".
# NOTE(review): presumably a workaround for hosted runtimes where `!` shell
# commands fail without a UTF-8 locale - confirm it is still needed.
locale.getpreferredencoding = lambda d=True: "UTF-8"

# PyGithub provides the `github` package imported below.
!pip install PyGithub

import shutil
import sys
from github import Auth, Github, GithubException


def quick_development():
    """Upload the trained model to the "ONNX ModelHub" GitHub release.

    The trained ``<task_name>.onnx`` is copied to ``<onnx_archive_name>.onnx``
    in the same directory and that copy is attached to each release titled
    "ONNX ModelHub" until one upload succeeds.

    Returns:
        The id of the uploaded release asset, or ``None`` when nothing was
        uploaded (no matching release, or every upload attempt failed).
    """
    auth = Auth.Token(os.getenv("GITHUB_TOKEN"))
    repo = Github(auth=auth).get_repo("QIN2DIM/hcaptcha-challenger")
    modelhub_title = "ONNX ModelHub"

    model_path = model_dir.joinpath(task_name, f"{task_name}.onnx")
    pending_onnx_path = model_dir.joinpath(task_name, f"{onnx_archive_name}.onnx")
    # shutil.copy raises SameFileError when both names coincide; nothing to copy then.
    if pending_onnx_path != model_path:
        shutil.copy(model_path, pending_onnx_path)

    release_found = False
    for release in repo.get_releases():
        if release.title != modelhub_title:
            continue
        release_found = True
        try:
            asset = release.upload_asset(path=str(pending_onnx_path))
        except GithubException as err:
            if err.status == 422:
                # No extra args/kwargs here: loguru would run str.format() on the
                # message, and `releases_url` is a URI template containing braces.
                logger.error(
                    f"The model file already exists, please manually replace the file with the same name - url={repo.releases_url}"
                )
            else:
                # Other API failures used to be dropped silently.
                logger.error(err)
        except Exception as err:
            logger.error(err)
        else:
            logger.success(
                f"Model file uploaded successfully "
                f"- name={asset.name} url={asset.browser_download_url}"
            )
            return asset.id

    if not release_found:
        logger.warning(f"Release not found, skip uploading - title={modelhub_title}")
    return None


# Without a token nothing can be uploaded, so stop this cell here.
if not os.getenv("GITHUB_TOKEN"):
    logger.warning("Skip model deployment, miss GITHUB TOKEN")
    sys.exit()

# `aid` (the uploaded asset's id, or None) is consumed by the rolling-upgrade cell.
aid = quick_development()
aid

Rolling upgrade

In [None]:
import inspect
import os
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, List

import yaml
from github.GitReleaseAsset import GitReleaseAsset
from github.Repository import Repository
from hcaptcha_challenger import ModelHub, handle
import hcaptcha_challenger as solver


@dataclass
class Objects:
    """In-memory copy of the ModelHub objects YAML document.

    Every field is filled from the top-level YAML key of the same name.
    """

    branches: Dict[str, Any]
    circle_seg: str
    clip_candidates: Dict[str, List[str]]
    nested_categories: Dict[str, List[str]]
    ashes_of_war: Dict[str, Any]
    label_alias: Dict[str, Any]
    datalake: dict

    @classmethod
    def from_modelhub(cls, modelhub: ModelHub):
        """Load the YAML file at ``modelhub.objects_path`` into an instance.

        Raises:
            KeyError: a field without a default has no matching key in the file.
        """
        data = yaml.safe_load(modelhub.objects_path.read_text(encoding="utf8"))
        # Fields without a default are mandatory; the others fall back to their default.
        return cls(
            **{
                key: (data[key] if val.default == val.empty else data.get(key, val.default))
                for key, val in inspect.signature(cls).parameters.items()
            }
        )

    def to_yaml(self, path: Path = None):
        """Dump all fields to ``path`` (default ``objects-tmp.yaml``) and return that path."""
        path = path or Path("objects-tmp.yaml")
        with open(path, "w", encoding="utf8") as file:
            # Keep the field order and write non-ASCII text verbatim.
            yaml.safe_dump(self.__dict__, file, sort_keys=False, allow_unicode=True)
        return path

    @staticmethod
    def to_asset(repo: Repository, data_tmp_path: Path, message: str = ""):
        """Commit the file at ``data_tmp_path`` as ``src/objects.yaml`` on branch ``main``.

        Args:
            repo: repository that receives the commit.
            data_tmp_path: local file whose bytes become the new content.
            message: commit message; a timestamped default is used when empty.

        Returns:
            Whatever ``repo.update_file`` returns.
        """
        from datetime import timezone

        content = data_tmp_path.read_bytes()
        # datetime.utcnow() is deprecated since Python 3.12; take an aware "now"
        # and drop the tzinfo so the rendered timestamp keeps its old format.
        utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
        message = message or f"Automated deployment @ utc {utc_now}"
        remote_path = "src/objects.yaml"
        # The blob sha of the current remote file is passed along with the update.
        sha = repo.get_contents(path=remote_path).sha
        return repo.update_file(
            branch="main", path=remote_path, message=message, content=content, sha=sha
        )


class Annotator:
    """Register an uploaded ONNX release asset in the remote ``src/objects.yaml``.

    The asset is resolved from its id, the in-memory ``Objects`` copy is patched
    (``label_alias`` for plain ResNet models, ``nested_categories`` for
    ``nested_*`` models) and the result is committed via ``Objects.to_asset``.
    """

    # NOTE: both attributes are evaluated once, while the class body executes,
    # so GITHUB_TOKEN has to be present in the environment at that point.
    auth = Auth.Token(os.getenv("GITHUB_TOKEN"))
    repo = Github(auth=auth).get_repo("QIN2DIM/hcaptcha-challenger")

    def __init__(self, asset_id: int, matched_label: str = ""):
        """
        Args:
            asset_id: id of the release asset that was just uploaded.
            matched_label: label (plain models) or bound prompt (nested models);
                when empty, the label of a plain model is parsed from the asset name.
        """
        self._asset_id = asset_id
        self._matched_label = matched_label

        # Resolved lazily by the `asset` property.
        self._asset: GitReleaseAsset | None = None

        # NOTE(review): presumably refreshes the locally cached ModelHub files - confirm.
        solver.install(upgrade=True)

        self.modelhub = ModelHub.from_github_repo()
        self.modelhub.parse_objects()

        self.data: Objects = Objects.from_modelhub(modelhub=self.modelhub)

    @property
    def asset(self):
        """The release asset for ``asset_id``, fetched on first access and cached."""
        if self._asset is None:
            self._asset = self.repo.get_release_asset(self._asset_id)
        return self._asset

    @staticmethod
    def parse_resnet_label(asset_name: str) -> str:
        """Return the label part of an asset name.

        The label is everything before the first digit, with ``.onnx`` removed
        and underscores turned into spaces::

            dog2312.onnx          -> "dog"
            chess_piece2309.onnx  -> "chess piece"

        A name without any digit is returned whole (it used to lose its last
        character).
        """
        onnx_archive = asset_name.replace(".onnx", "")
        i_end = next(
            (i for i, s in enumerate(onnx_archive) if s.isdigit()), len(onnx_archive)
        )
        return onnx_archive[:i_end].replace("_", " ")

    def handle_resnet_objects(self):
        """Point the model's label at the new archive name inside ``label_alias``."""
        onnx_archive = self.asset.name.replace(".onnx", "")
        matched_label = self._matched_label or self.parse_resnet_label(self.asset.name)
        old_onnx_archive = self.modelhub.label_alias.get(matched_label)

        if not old_onnx_archive:
            # Match: create new case - register the label as its only English alias.
            self.data.label_alias[onnx_archive] = {"en": [matched_label]}
        else:
            # Match: update old case - move the existing i18n aliases to the new name.
            self.data.label_alias[onnx_archive] = self.data.label_alias.pop(old_onnx_archive)

    def handle_nested_objects(self, model_pending: str):
        """Insert ``model_pending`` into the model list of its bound nested prompt.

        Match nested cases:
        - the largest animal
        - the smallest animal

        Older versions of the same model (same label prefix) are dropped from
        the list before the new one is appended.

        Raises:
            ValueError: no prompt was supplied for the nested model.
            TypeError: the prompt is registered with a non-empty value that is not a list.
        """
        bond_nested_prompt = handle(self._matched_label)
        if not bond_nested_prompt:
            raise ValueError("Nested model requires binding prompt")

        # nested_largest_dog2309.onnx nested_largest_elephant2309.onnx
        prefix_tag_pending = self.parse_resnet_label(model_pending)

        # An unregistered prompt, or one registered with an empty value, starts
        # from an empty list.
        nested_models = self.modelhub.nested_categories.get(bond_nested_prompt) or []
        # The prompt is registered but holds a non-list value: refuse to guess.
        if not isinstance(nested_models, list):
            raise TypeError(
                f"NestedTypeError ({bond_nested_prompt}) 的模型映射列表应该是个 List[str] 类型，但实际上是 {type(nested_models)}"
            )

        # Build a new list without the stale versions; popping indices collected
        # beforehand would shift the remaining ones and delete the wrong entries.
        nested_models = [
            model_name
            for model_name in nested_models
            if self.parse_resnet_label(model_name) != prefix_tag_pending
        ]
        nested_models.append(model_pending)

        # Store the updated model list for this prompt.
        self.data.nested_categories[bond_nested_prompt] = nested_models

    def flush_remote_objects(self):
        """Serialise the patched objects and commit them to the repository."""
        data_tmp_path = self.data.to_yaml()
        try:
            res = self.data.to_asset(
                self.repo,
                message=f"ci(annotator): update model `{self.asset.name}`",
                data_tmp_path=data_tmp_path,
            )
            logger.success("upgrade objects", response=res)
        finally:
            # Remove the temporary file even when the commit fails.
            os.remove(data_tmp_path)

    def execute(self):
        """Patch the objects for the asset and push them; ``yolov8`` assets are ignored."""
        logger.debug("capture asset", name=self.asset.name, url=self.asset.browser_download_url)

        # Match: ResNet MoE models
        if "yolov8" in self.asset.name:
            return
        if "nested_" in self.asset.name:
            self.handle_nested_objects(self.asset.name)
        else:
            self.handle_resnet_objects()

        self.flush_remote_objects()


def rolling_upgrade(asset_id=None, matched_label: str = ""):
    """
    Register the uploaded asset in the remote objects file; do nothing without an asset id.

    For a nested model, `matched_label` must carry the prompt the model is bound to.
    Any failure is logged as a warning instead of being raised.
    """
    if asset_id:
        try:
            Annotator(asset_id, matched_label=matched_label).execute()
        except Exception as exc:
            logger.warning(exc)


# The upgrade talks to GitHub as well, so it is skipped without a token.
if not os.getenv("GITHUB_TOKEN"):
    logger.warning("Skip the rolling upgrade task, miss GITHUB TOKEN")
    sys.exit()

# NESTED_PROMPT is forwarded as the label/prompt the uploaded model is bound to.
rolling_upgrade(asset_id=aid, matched_label=NESTED_PROMPT)

In [None]:
"""
energy conservation
"""
if IS_AUTODL_PLATFORM:
    # `unset` run through os.system() only changes the short-lived child shell,
    # so remove the proxy variables from this process' own environment instead.
    for proxy_var in ("http_proxy", "https_proxy"):
        os.environ.pop(proxy_var, None)
    # Power the instance off once the notebook has finished.
    os.system("/usr/bin/shutdown")