In [None]:
%load_ext autoreload
%load_ext aymurai.devtools.magic
%autoreload 2

In [None]:
from copy import deepcopy
import shutil

from aymurai.logging import get_logger
from aymurai.meta.types import DataItem, DataBlock
from aymurai.meta.pipeline_interfaces import TrainModule
from aymurai.models.decision.torch.tokenizer import Tokenizer
from aymurai.models.decision.torch.conv1d import Conv1dTextClassifier
from aymurai.meta.api_interfaces import DocLabel, DocLabelAttributes
from aymurai.utils.misc import get_element
import os

logger = get_logger(__name__)

# FIXME: loading the tokenizer allocates some GPU memory, apparently because of spaCy; investigate.


class DecisionConv1dMulticlass(TrainModule):
    """Decision classifier built from a ``Tokenizer`` and a ``Conv1dTextClassifier``.

    For every item whose predicted category is non-zero, a ``DECISION`` entity
    covering the whole ``doc.text`` is appended to
    ``item["predictions"]["entities"]``. Category 0 means "not a decision" and
    leaves the item without a new entity.
    """

    def __init__(
        self,
        tokenizer_path: str,
        model_checkpoint: str,
        device: str = "cpu",
    ):
        """Load the tokenizer and the model checkpoint.

        Args:
            tokenizer_path: file passed to ``Tokenizer.load``.
            model_checkpoint: checkpoint passed to
                ``Conv1dTextClassifier.load_from_checkpoint``.
            device: forwarded as ``map_location`` when loading the checkpoint.
        """
        self._device = device
        self._tokenize_path = tokenizer_path
        self._model_path = model_checkpoint

        self.tokenizer = Tokenizer.load(self._tokenize_path)
        self.model = Conv1dTextClassifier.load_from_checkpoint(
            self._model_path,
            map_location=self._device,
        )
        # This module only runs inference (fit is not implemented), so put the
        # network in evaluation mode.
        # NOTE(review): assumes the classifier is a torch ``nn.Module`` (it
        # exposes ``load_from_checkpoint`` and ``.device``) - confirm.
        self.model.eval()

    def save(self, basepath: str) -> dict | None:
        """Write the tokenizer and a copy of the checkpoint under ``basepath``.

        The files are ``tokenizer.pth`` and ``model.ckpt``; the internal paths
        are updated to point at the new copies. Returns ``None``.
        """
        # save tokenizer
        os.makedirs(basepath, exist_ok=True)
        self._tokenize_path = f"{basepath}/tokenizer.pth"
        self.tokenizer.save(self._tokenize_path)
        logger.info(f"tokenizer saved on: {self._tokenize_path}")

        # save model
        new_model_path = f"{basepath}/model.ckpt"
        try:
            shutil.copy(self._model_path, new_model_path)
        except shutil.SameFileError:
            # Saving into the directory the module was loaded from: the
            # checkpoint is already in place, nothing to copy.
            logger.debug(f"model checkpoint already at: {new_model_path}")
        self._model_path = new_model_path
        logger.info(f"model saved on: {self._model_path}")

    @classmethod
    def load(cls, path: str, **kwargs):
        """Build an instance from the files written by ``save`` under ``path``.

        Extra keyword arguments (e.g. ``device``) are forwarded to ``__init__``.
        """
        return cls(
            tokenizer_path=f"{path}/tokenizer.pth",
            model_checkpoint=f"{path}/model.ckpt",
            **kwargs,
        )

    def fit(self, train: DataBlock, val: DataBlock):
        """Training is not implemented; only a warning is logged."""
        logger.warning("fit routine not implemented")

    def predict(self, data: DataBlock) -> DataBlock:
        """Run ``predict_single`` on every item and return the results as a list."""
        # FIXME: optimize
        # ``Logger.warn`` is a deprecated alias; use ``warning``.
        logger.warning('predict not optimized')
        return [self.predict_single(item) for item in data]

    def gen_aymurai_entity(self, text: str, category: int, score: float):
        """Build the ``DECISION`` entity dict for ``text``.

        Args:
            text: text the entity covers; the span is ``0..len(text)``.
            category: predicted class; 1 is labelled ``"no_hace_lugar"``, any
                other value ``"hace_lugar"``.
            score: stored as ``aymurai_score``.

        Returns:
            ``DocLabel(...).dict()`` with an extra ``"label": "DECISION"`` key.
        """
        attrs = DocLabelAttributes(
            aymurai_label="DECISION",
            aymurai_label_subclass=["no_hace_lugar" if category == 1 else "hace_lugar"],
            # Instances do not expose ``__name__``; read it from the class.
            aymurai_method=self.__class__.__name__,
            aymurai_score=score,
        )

        ent = DocLabel(
            text=text,
            start_char=0,
            end_char=len(text),
            attrs=attrs,
        )
        ent = ent.dict()
        ent["label"] = "DECISION"
        return ent

    def predict_single(self, item: DataItem) -> DataItem:
        """Classify one item and return an annotated deep copy of it.

        The input item is never mutated. The model output is treated as
        log-probabilities: the arg-max class is the category and its
        probability the score. When the category is 0 the copy is returned
        without a new entity.
        """
        item = deepcopy(item)

        text = item["data"]["doc.text"]
        input_ids = self.tokenizer.encode_batch([text]).to(self.model.device)
        log_prob = self.model(input_ids)[0]
        prob = log_prob.exp()

        category = int(prob.argmax())
        score = float(prob[category])

        if category == 0:  # not a decision
            return item

        # Keep entities already predicted by earlier pipeline stages.
        ents = get_element(item, ["predictions", "entities"]) or []

        ent = self.gen_aymurai_entity(text=text, category=category, score=score)
        ents.append(ent)

        if "predictions" not in item:
            item["predictions"] = {}

        item["predictions"]["entities"] = ents

        return item


In [None]:
# Load the tokenizer from the relative file 'tokenizer.pth'; the result is only displayed, not stored.
Tokenizer.load('tokenizer.pth')

In [None]:
# Build the classifier on CPU from the local tokenizer file and a training checkpoint.
# The first positional argument is `tokenizer_path`.
model = DecisionConv1dMulticlass(
    "tokenizer.pth",
    model_checkpoint="/workspace/notebooks/experiments/decision/checkpoints/pl-emb-conv/epoch=38-step=6981.ckpt",
    device="cpu",
)


In [None]:
# Write tokenizer.pth and model.ckpt under the relative directory 'test/conv'.
model.save('test/conv')

In [None]:
# Reload from the directory written by the save cell. `load` is a classmethod, here called
# through the existing instance; `model` is rebound to the new object.
model = model.load('test/conv')

In [None]:
# Classify a single item and display the annotated copy.
# NOTE(review): `item` is not defined in any cell visible above; this cell fails on a fresh
# kernel unless `item` is created first (expected shape: {"data": {"doc.text": ...}}).
pred = model.predict_single(item)
pred

In [None]:
from aymurai.spacy.display import DocRender

# Colour assigned to each entity label when rendering.
colors = {"DECISION": "Aquamarine"}
render = DocRender(config={"colors": colors})

# Show the single prediction produced in the previous cell.
render(pred)


In [None]:
from aymurai.models.flair.utils import FlairTextNormalize
from aymurai.models.flair.core import FlairModel
from aymurai.pipeline import AymurAIPipeline

# Pipeline definition: one text-normalisation preprocess step, then two models
# (the Flair model followed by the decision classifier defined in this notebook),
# no postprocess steps, caching disabled.
config = {
    "preprocess": [
        (FlairTextNormalize, {}),
    ],
    "models": [
        (
            FlairModel,
            {
                "basepath": "/resources/pipelines/examples/flair-simple/FlairModel",
                "split_doc": True,
                "device": "cpu",
            },
        ),
        (
            DecisionConv1dMulticlass,
            {
                # NOTE(review): the tokenizer is read from the experiment directory while the
                # checkpoint is read from test/conv; presumably the same artifacts written by
                # the save cell when the notebook runs from that directory - confirm.
                "tokenizer_path": "/workspace/notebooks/experiments/decision/tokenizer.pth",
                "model_checkpoint": "/workspace/notebooks/experiments/decision/test/conv/model.ckpt",
                "device": "cpu",
            },
        ),
    ],
    "postprocess": [],
    "multiprocessing": {},
    "use_cache": False,
}

pipeline = AymurAIPipeline(config)


In [None]:
from aymurai.datasets.ar_juz_pcyf_10.annotations import ArgentinaJuzgadoPCyF10LabelStudioAnnotations

# Load the annotated documents from the restricted annotations directory; `.data` is
# indexed and iterated as a sequence of item dicts in the cells below.
data = ArgentinaJuzgadoPCyF10LabelStudioAnnotations('/resources/data/restricted/annotations/20221122-bis/').data

In [None]:
# Inspect how the 'path' field of an item is written (used to build the filter in the next cell).
data[0]['path']

In [None]:
# Select the annotated documents whose 'path' equals the target file (escaped exactly as stored).
a = [
    doc
    for doc in data
    if doc['path'] == '\\/resources\\/restricted\\/ar-juz-pcyf-10\\/RESOLUCIONES DEL JUZGADO - DOCS\\/Suspensión del proceso a prueba\\/Otorga probation\\/1542.docx'
]
a

In [None]:
from aymurai.transforms.entities import EntityToSpans

# Transform configured on the "predictions" field with span key "sc".
# NOTE(review): no cell visible in this notebook applies `entity2span_transform`; confirm it is still needed.
entity2span_transform = EntityToSpans(field="predictions", span_key="sc")


In [None]:
from copy import deepcopy

# `idx` only matters for the commented-out alternative below (pick a document by position);
# the last uncommented assignment wins, so idx ends up as 35.
idx = 300
# idx = 124
idx = 35
# idx = 56
# idx = 75

# Split the selected document into lines and wrap each line as its own item,
# so every line is classified independently.
example = [
    {"path": "empty", "data": {"doc.text": text}}
    for text in a[0]['data']['doc.text'].splitlines()
    # alternative source: the document at position `idx` of the full dataset
    # for text in data[idx]["data"]["doc.text"].splitlines()
]



In [None]:
# Preprocess the per-line items, then run the pipeline prediction on each of them.
pred = [
    pipeline.predict_single(prepared)
    for prepared in pipeline.preprocess(example)
]


In [None]:
def add_score_to_label(ent):
    """Overwrite ``ent['label']`` with ``"<label>:<subclasses>:<score>"``.

    The parts are read from ``ent['attrs']`` under ``aymurai_label``,
    ``aymurai_label_subclass`` and ``aymurai_score``. A falsy label or
    subclass value is rendered as an empty string. The score is rendered with
    two decimals, or as an empty string when it is missing or not numeric.

    Args:
        ent: entity dict; it is modified in place.

    Returns:
        The same ``ent`` object.
    """
    score = get_element(ent, ['attrs', 'aymurai_score'])
    cats = get_element(ent, ['attrs', 'aymurai_label_subclass']) or None
    label = get_element(ent, ['attrs', 'aymurai_label']) or ''

    # A float format spec cannot be applied to an empty-string fallback, so the
    # score is formatted on its own; a legitimate 0.0 score is kept as "0.00".
    try:
        score_text = f"{float(score):1.2f}"
    except (TypeError, ValueError):
        score_text = ''

    ent['label'] = f"{label}:{cats or ''}:{score_text}"
    return ent

In [None]:
# Render each predicted line using the span style, reading spans from the 'sc' key.
# NOTE(review): no visible cell converts the predicted entities into spans under 'sc'
# (the EntityToSpans transform is created but never applied here) - confirm the pipeline does it.
for p in pred:
    # Optional relabelling (disabled): append subclass and score to each entity label.
    # ents = get_element(p, ['predictions', 'entities']) or []
    # p['predictions']['entities'] = [add_score_to_label(ent) for ent in ents]
    render(p, style='span', spans_key='sc')