# 1. Multilabel classifier

## Data and categories 

In [None]:
import pandas as pd

In [None]:
# Load the Spanish summaries dataset (JSON Lines: one record per line)
# and preview the first rows.
df = pd.read_json(
    'datasets/es_sum_mini.json',
    orient='records',
    lines=True,
)
df.head()

In [None]:
# Candidate labels for zero-shot multi-label classification:
# two sentiment labels followed by six topic labels.
categories = [
    'positivo',
    'negativo',
    'economía',
    'electricidad',
    'telecomunicaciones',
    'ecología',
    'política',
    'energía',
]

## Pretrained zero-shot classifier

In [None]:
from transformers import pipeline

# Zero-shot classification pipeline backed by a compact MNLI model.
# The larger "joeddav/xlm-roberta-large-xnli" model was rejected as too big.
# NOTE(review): squeezebert-mnli may be English-only while the inputs are
# Spanish summaries; confirm prediction quality is acceptable.
classifier = pipeline(
    "zero-shot-classification",
    model="typeform/squeezebert-mnli",
)

In [None]:
# Sanity check on one (truncated) Spanish sentence. `multi_label=True` scores
# every candidate label independently, so several labels can apply at once;
# it replaces the deprecated `multi_class` argument.
classifier("A ERC y Crida per Sabadell (CUP), que hasta ahora..", candidate_labels=categories, multi_label=True)

## Log predictions in Rubrix

In [None]:
import rubrix
import os

In [None]:
# Connect to the Rubrix server. The URL and API key come from environment
# variables (a KeyError is raised if either is unset), so no credentials are
# stored in the notebook itself.
rubrix.init(
    api_url=os.environ["RUBRIX_API_URL"],
    api_key=os.environ["RUBRIX_API_KEY"],
)

In [None]:
# Classify the first 100 summaries with the zero-shot model and log each
# prediction to the "red_electrica_multilabel" Rubrix dataset.
for _, row in df.iloc[:100].iterrows():
    text = row['summary']
    # Zero-shot prediction; multi_label=True scores every category independently
    # (`multi_class` is the deprecated spelling of this argument).
    preds = classifier(text, candidate_labels=categories, multi_label=True)
    record = rubrix.TextClassificationRecord(
        inputs={"text": text},
        # Rubrix expects a flat list of (label, score) pairs, one per category.
        prediction=list(zip(preds['labels'], preds['scores'])),
        prediction_agent="dvilasuero",
        multi_label=True,
        metadata={'model': 'typeform/squeezebert-mnli'},
        event_timestamp=row['date'],
    )
    # Log one record at a time.
    rubrix.log(records=record, dataset="red_electrica_multilabel")

# 2. Entity classifier

## spaCy pretrained model

Note: this pretrained model is not the most accurate option available.

In [None]:
import spacy

# Load the small Spanish pipeline by its package name. The bare 'es' shortcut
# depended on spaCy v2 shortcut links, which were removed in spaCy v3;
# 'es_core_news_sm' is the package that shortcut pointed to.
nlp = spacy.load('es_core_news_sm')

In [None]:
# Smoke-test the NER model on a hand-written sentence: print the character
# span and label of every entity it finds.
doc = nlp('Esto es una prueba sobre Mariano Rajoy, ex-presidente del PP, la loca de Pontevedra')
entity_spans = [(ent.start_char, ent.end_char, ent.label_) for ent in doc.ents]
for start, end, label in entity_spans:
    print(start, end, label)

## Log predictions in Rubrix

In [None]:
# Run spaCy NER over the first 100 summaries and log each result to the
# "red_electrica_entities" Rubrix dataset, tagged as an NER task.
for _, row in df.iloc[:100].iterrows():
    text = row['summary']
    doc = nlp(text)
    record = rubrix.TokenClassificationRecord(
        text=text,
        tokens=[token.text for token in doc],
        # Entities as (label, start_char, end_char) taken from the parsed `doc`.
        prediction=[(ent.label_, ent.start_char, ent.end_char) for ent in doc.ents],
        prediction_agent="spacy_v2",
        metadata={'model': 'spacy_es_core_news_sm'},
        event_timestamp=row['date'],
    )
    rubrix.log(record, dataset="red_electrica_entities", tags={"task": "ner"})