In [None]:
%load_ext autoreload
%load_ext aymurai.devtools.magic
%autoreload 2

In [None]:
# %%export aymurai.models.decision.binregex

import os
import shutil
from copy import deepcopy

import regex
import torch
import subprocess
from unidecode import unidecode

from aymurai.logging import get_logger
from aymurai.meta.types import DataItem, DataBlock
from aymurai.utils.misc import is_url, get_element
from aymurai.meta.pipeline_interfaces import TrainModule
from aymurai.meta.environment import AYMURAI_CACHE_BASEPATH
from aymurai.models.decision.tokenizer import Tokenizer
from aymurai.meta.api_interfaces import DocLabel, DocLabelAttributes
from aymurai.models.decision.conv1d import Conv1dTextClassifier
from aymurai.utils.download import download

logger = get_logger(__name__)

# FIXME: loading the tokenizer allocates some GPU memory (apparently because of spaCy); investigate.


class DecisionConv1dBinRegex(TrainModule):
    """Flag a text as a DECISION and tag it with a regex-derived subclass.

    A `Conv1dTextClassifier` checkpoint scores the text. When the score of
    class 1 exceeds `threshold`, a "DECISION" entity spanning the whole text is
    appended to the item's predicted entities. The entity subclass
    ("hace_lugar" / "no_hace_lugar") is chosen with a regular expression.
    """

    # Compiled once here instead of on every `get_subcategory` call.
    _PATTERN_NO_HACE_LUGAR = regex.compile(
        r"(?i)(no hacer? lugar|rechaz[ao]r?|no admitir|no convalidar|no autorizar|declarar inadmisible)"
    )

    def __init__(
        self,
        tokenizer_path: str,
        model_checkpoint: str,
        device: str = "cpu",
        threshold: float = 0.88,
        return_only_with_detalle: bool = True,
    ):
        """Load the tokenizer and the classifier, downloading them if needed.

        Args:
            tokenizer_path: local path or URL of the serialized tokenizer.
            model_checkpoint: local path or URL of the classifier checkpoint.
            device: value passed as `map_location` when loading the checkpoint.
            threshold: the class-1 score must be strictly greater than this
                value for a text to be flagged as a decision.
            return_only_with_detalle: when True, a DECISION entity is only
                emitted if the item already has an entity labelled "DETALLE".
        """
        self._device = device
        self._tokenizer_path = tokenizer_path
        self._model_path = model_checkpoint
        self.threshold = threshold
        self.return_only_with_detalle = return_only_with_detalle

        # download if needed
        # NOTE(review): `self.__name__` is not available on plain Python
        # instances; presumably the base class provides it -- confirm.
        basepath = os.getenv("AYMURAI_CACHE_BASEPATH", AYMURAI_CACHE_BASEPATH)
        ## tokenizer
        if is_url(url := self._tokenizer_path):
            output = f"{basepath}/{self.__name__}/tokenizer.pth"
            logger.info(f"downloading tokenizer on {output}")
            os.makedirs(os.path.dirname(output), exist_ok=True)
            self._tokenizer_path = download(url, output=output)
        ## model
        if is_url(url := self._model_path):
            output = f"{basepath}/{self.__name__}/model.ckpt"
            logger.info(f"downloading model on {output}")
            os.makedirs(os.path.dirname(output), exist_ok=True)
            self._model_path = download(url, output=output)

        self.tokenizer = Tokenizer.load(self._tokenizer_path)
        self.model = Conv1dTextClassifier.load_from_checkpoint(
            self._model_path,
            map_location=self._device,
        )
        self.model = self.model.eval()

    def save(self, basepath: str) -> dict | None:
        """Write the tokenizer and the checkpoint under `basepath`.

        The instance is re-pointed at the saved files.

        Returns:
            A dict with the new `tokenizer_path` and `model_checkpoint` plus
            the `device`.
        """
        # NOTE(review): `threshold` and `return_only_with_detalle` are not part
        # of the returned dict; if callers rebuild the module from it, those
        # settings fall back to their defaults -- confirm this is intended.
        os.makedirs(basepath, exist_ok=True)

        # save tokenizer
        self._tokenizer_path = f"{basepath}/tokenizer.pth"
        self.tokenizer.save(self._tokenizer_path)
        logger.info(f"tokenizer saved on: {self._tokenizer_path}")

        # save model
        new_model_path = f"{basepath}/model.ckpt"
        try:
            shutil.copy(self._model_path, new_model_path)
        except shutil.SameFileError:
            # Saving into the directory the checkpoint was loaded from (e.g.
            # load(path) followed by save(path)): the file is already in place.
            pass
        self._model_path = new_model_path
        logger.info(f"model saved on: {self._model_path}")

        return {
            "tokenizer_path": self._tokenizer_path,
            "model_checkpoint": self._model_path,
            "device": self._device,
        }

    @classmethod
    def load(cls, path: str, **kwargs):
        """Build an instance from a directory previously written by `save`.

        Extra keyword arguments are forwarded to the constructor.
        """
        return cls(
            tokenizer_path=f"{path}/tokenizer.pth",
            model_checkpoint=f"{path}/model.ckpt",
            **kwargs,
        )

    def fit(self, train: DataBlock, val: DataBlock):
        """Not implemented: only logs a warning."""
        logger.warning("fit routine not implemented")

    def predict(self, data: DataBlock) -> DataBlock:
        """Run `predict_single` on every item and return the results as a list."""
        # FIXME: optimize
        # `Logger.warn` is a deprecated alias; use `warning` as `fit` does.
        logger.warning("predict not optimized")
        return [self.predict_single(item) for item in data]

    def get_subcategory(self, text):
        """Return `["no_hace_lugar"]` if `text` matches the rejection pattern,
        otherwise `["hace_lugar"]`."""
        if self._PATTERN_NO_HACE_LUGAR.search(text):
            return ["no_hace_lugar"]
        return ["hace_lugar"]

    def gen_aymurai_entity(self, text: str, category: int, score: float):
        """Build the DECISION entity dict covering the whole `text`.

        Args:
            text: text the entity spans (from char 0 to `len(text)`).
            category: predicted class; currently unused.
            score: value stored as `aymurai_score`.
        """
        subcategory = self.get_subcategory(text)
        attrs = DocLabelAttributes(
            aymurai_label="DECISION",
            aymurai_label_subclass=subcategory,
            aymurai_method=self.__name__,
            aymurai_score=score,
        )

        ent = DocLabel(
            text=text,
            start_char=0,
            end_char=len(text),
            attrs=attrs,
        )
        ent = ent.dict()
        ent["label"] = "DECISION"
        ent["context_pre"] = ""
        ent["context_post"] = ""
        return ent

    def predict_single(self, item: DataItem) -> DataItem:
        """Return a copy of `item`, with a DECISION entity appended when the
        text is classified as a decision.

        The input item is never modified. The copy is returned unchanged when
        the score does not exceed `threshold`, or when
        `return_only_with_detalle` is set and no "DETALLE" entity is present.
        """
        item = deepcopy(item)

        # The classifier and the regex both work on the transliterated text.
        # NOTE(review): the entity text/offsets therefore refer to the
        # transliterated text, whose length may differ from the original
        # `doc.text` -- confirm this is intended.
        text = unidecode(item["data"]["doc.text"])
        input_ids = self.tokenizer.encode_batch([text]).to(self.model.device)
        with torch.no_grad():
            # presumably the model emits log-probabilities, hence exp() -- TODO confirm
            probs = self.model(input_ids).exp()

        # using category 1 as global score (binary).
        # `.cpu()` is needed before `.numpy()` whenever the model is not on the
        # CPU; `float()` turns the numpy scalar into a plain Python float.
        score = float(probs.detach().cpu().numpy()[0, 1])

        category = int(score > self.threshold)
        if category == 0:  # not a decision
            return item

        ents = get_element(item, ["predictions", "entities"]) or []
        detalles = [ent for ent in ents if ent["label"] == "DETALLE"]
        if self.return_only_with_detalle and not detalles:
            return item

        ents.append(self.gen_aymurai_entity(text=text, category=category, score=score))

        if "predictions" not in item:
            item["predictions"] = {}
        item["predictions"]["entities"] = ents

        return item


In [None]:
# Remote artefacts: URLs are downloaded by the constructor into the cache
# directory. Either argument may instead be a local file path; checkpoints
# tried during experiments live under
# /workspace/notebooks/experiments/decision/checkpoints/413-torch-binary-emb-conv1d/
TOKENIZER_URL = "https://drive.google.com/uc?id=1eljQOinpObdfBREIKxVnC5Y2g_sbhPHT&confirm=true"
MODEL_CHECKPOINT_URL = "https://drive.google.com/uc?id=19_YmBJnO06iS0qW8ak0zl0EIsJYin8kQ&confirm=true"

model = DecisionConv1dBinRegex(
    tokenizer_path=TOKENIZER_URL,
    model_checkpoint=MODEL_CHECKPOINT_URL,
    device="cpu",
)


In [None]:
# Persist tokenizer and checkpoint under a local directory; the returned
# config dict is displayed as the cell output.
model.save(basepath='test/conv')

In [None]:
# `load` is a classmethod, so it is called on the class rather than through
# the existing instance; the freshly loaded module replaces `model`.
model = DecisionConv1dBinRegex.load(path='test/conv')

In [None]:
from aymurai.utils.display import DocRender

# Highlight colour per entity label, handed to the renderer configuration.
colors = dict(DECISION='Aquamarine')
render = DocRender(config=dict(colors=colors))

# Usage, once predictions exist: render(pred)


In [None]:
from aymurai.models.flair.utils import FlairTextNormalize
from aymurai.models.flair.core import FlairModel
from aymurai.pipeline import AymurAIPipeline
# NOTE: this import rebinds `DecisionConv1dBinRegex` to the packaged class,
# shadowing the class defined earlier in this notebook.
from aymurai.models.decision.binregex import DecisionConv1dBinRegex

# Each pipeline step is a (class, constructor kwargs) pair.
flair_step = (
    FlairModel,
    {
        "basepath": "/resources/pipelines/examples/flair-simple/FlairModel",
        "split_doc": True,
        "device": "cpu",
    },
)

# Other checkpoints tried: test/conv/model.ckpt and epoch=7-step=208.ckpt.
decision_step = (
    DecisionConv1dBinRegex,
    {
        "tokenizer_path": "tokenizer.pth",
        "model_checkpoint": "/workspace/notebooks/experiments/decision/checkpoints/413-torch-binary-emb-conv1d/epoch=8-step=234.ckpt",
        "device": "cpu",
    },
)

config = {
    "preprocess": [(FlairTextNormalize, {})],
    "models": [flair_step, decision_step],
    "postprocess": [],
    "multiprocessing": {},
    "use_cache": False,
}

pipeline = AymurAIPipeline(config)


In [None]:
# pipeline.save('/resources/pipelines/production/full-pipeline')

In [None]:
from aymurai.datasets.ar_juz_pcyf_10.annotations import ArgentinaJuzgadoPCyF10LabelStudioAnnotations

# Directory holding the annotation export to read.
ANNOTATIONS_DIR = '/resources/data/restricted/annotations/20221122-bis/'

annotations = ArgentinaJuzgadoPCyF10LabelStudioAnnotations(ANNOTATIONS_DIR)
data = annotations.data

In [None]:
# Show the source path of the first loaded item as the cell output.
first_item = data[0]
first_item['path']

In [None]:
from aymurai.transforms.entities import EntityToSpans

# Transform configured on the "predictions" field with span key "sc".
span_options = {"field": "predictions", "span_key": "sc"}
entity2span_transform = EntityToSpans(**span_options)


In [None]:
from copy import deepcopy

# Index of the loaded document to inspect (56 was another document tried;
# the earlier `idx = 56` assignment was dead, being overwritten right away).
idx = 45

# Build one pseudo-item per line of the selected document's text, each line
# stripped of surrounding whitespace.
doc_text = data[idx]["data"]["doc.text"]
example = [
    {"path": "empty", "data": {"doc.text": line.strip()}}
    for line in doc_text.splitlines()
]



In [None]:
# Stage 1: run the pipeline's preprocessing over every pseudo-item.
preprocessed = pipeline.preprocess(example)
# Stage 2: predict item by item.
pred = list(map(pipeline.predict_single, preprocessed))


In [None]:
from aymurai.utils.misc import get_element

def add_score_to_label(ent):
    """Rewrite `ent['label']` as "<label>:<subclass>:<score>" and return `ent`.

    The three parts are read from `ent['attrs']` (`aymurai_label`,
    `aymurai_label_subclass`, `aymurai_score`). A missing part is rendered as
    an empty string; the score is rendered with two decimals. The entity is
    modified in place.

    Assumes `get_element` returns None for a missing key -- TODO confirm.
    """
    score = get_element(ent, ['attrs', 'aymurai_score'])
    cats = get_element(ent, ['attrs', 'aymurai_label_subclass']) or ''
    label = get_element(ent, ['attrs', 'aymurai_label']) or ''

    # Apply the float format spec only to an actual score: formatting the ''
    # placeholder with "1.2f" raises ValueError. Checking `is not None` also
    # keeps a legitimate score of 0 visible.
    score_text = f"{score:1.2f}" if score is not None else ''

    ent['label'] = f"{label}:{cats}:{score_text}"
    return ent

In [None]:
for p in pred:
    # Relabel every predicted entity as "<label>:<subclass>:<score>".
    ents = get_element(p, ['predictions', 'entities']) or []
    ents = [add_score_to_label(ent) for ent in ents]
    # The lookup above tolerates items without predictions, so the write-back
    # must too: create the "predictions" dict when it is absent.
    p.setdefault('predictions', {})['entities'] = ents

    # The labels now embed subclass and score, so colours are keyed per
    # relabelled DECISION entity.
    options = {'colors': {e['label']: 'lightblue' for e in ents if e['label'].startswith('DECISION')}}

    render(p, style='span', spans_key='sc', config=options)