First, we import the required libraries. In addition to `liqfit`, we need `transformers` to load models from the **HuggingFace Hub** and `datasets` to download the dataset.

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TrainingArguments, Trainer
from liqfit.datasets import NLIDataset
from liqfit.collators import NLICollator
from datasets import load_dataset

### Dataset

In this case, we will use the `dair-ai/emotion` dataset. It consists of *English Twitter messages* with six basic **emotions**: *anger*, *fear*, *joy*, *love*, *sadness*, and *surprise*. 

We will randomly select 48 training examples (8 per label on average) and use the whole test set for evaluation.

In [None]:

# Load the emotion dataset from the HuggingFace Hub
emotion_dataset = load_dataset("dair-ai/emotion")

test_dataset = emotion_dataset['test']

classes = test_dataset.features["label"].names

# Random sample of len(classes) * N examples: N per label on average, not exactly N per class
N = 8
train_dataset = emotion_dataset['train'].shuffle(seed=41).select(range(len(classes) * N))
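
Because the sample above is drawn uniformly at random, rare labels may receive fewer than 8 examples. If an exact per-class count matters, a stratified draw is one alternative. The sketch below is a minimal illustration on a toy label list; the helper `sample_per_class` is hypothetical and is not part of `liqfit` or `datasets`.

```python
import random
from collections import Counter, defaultdict


def sample_per_class(labels, n_per_class, seed=41):
    """Return sorted indices holding exactly n_per_class examples of each label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    selected = []
    for label, indices in sorted(by_label.items()):
        if len(indices) < n_per_class:
            raise ValueError(f"label {label} has only {len(indices)} examples")
        selected.extend(rng.sample(indices, n_per_class))
    return sorted(selected)


# Toy, imbalanced label list standing in for emotion_dataset['train']['label'].
toy_labels = [0] * 50 + [1] * 30 + [2] * 12
balanced_idx = sample_per_class(toy_labels, n_per_class=8)
print(Counter(toy_labels[i] for i in balanced_idx))
```

With the real data, the same indices could be passed to `emotion_dataset['train'].select(...)` after calling the helper on `emotion_dataset['train']['label']`.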

After downloading and sampling the dataset, we need to preprocess it to adapt to the **Natural Language Inference (NLI)** format. 

NLI is a crucial task in NLP that focuses on determining the logical relationship between two given text segments: a premise and a hypothesis. 

The goal is to ascertain whether the hypothesis can be inferred from the premise, falling into one of three categories: entailment, contradiction, or neutral. 

In our case, we pair each text with statements about its class, such as *"This text belongs to Business class"*, including negative (false) statements. We then train a classifier to decide whether each statement is true or false.
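
The pairing idea can be sketched in plain Python. This is only an illustration: the helper `build_nli_pairs`, the string relation names, and the choice to emit one negative for every other class are assumptions, not LiqFit's actual implementation.

```python
def build_nli_pairs(text, true_label, classes, template="This example is {}."):
    """Pair one text with a hypothesis per class and mark which one is true."""
    pairs = []
    for name in classes:
        pairs.append({
            "premise": text,
            "hypothesis": template.format(name),
            # Assumption: true statements are entailment, false ones contradiction.
            "relation": "entailment" if name == true_label else "contradiction",
        })
    return pairs


toy_classes = ["joy", "sadness", "anger"]
toy_pairs = build_nli_pairs("i feel wonderful today", "joy", toy_classes)
for pair in toy_pairs:
    print(pair["hypothesis"], "->", pair["relation"])
```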

To transform the dataset, we use the `load_dataset` class method of `NLIDataset`, which takes the following parameters:

**Args**

- `dataset` (Optional[Dataset], optional): Instance of Huggingface Dataset class. Defaults to None.

- `dataset_name` (Optional[str], optional): Dataset name to load from Huggingface datasets. Defaults to None.

- `classes` (Optional[List[str]], optional): List of classes. Defaults to None.

- `text_column` (Optional[str], optional): Text column name. Defaults to 'text'.

- `label_column` (Optional[str], optional): Label column name. Defaults to 'label'.

- `template` (Optional[str], optional): Template string that will be used for Zero-Shot training/prediction. Defaults to 'This example is {}.'.

- `normalize_negatives` (bool, optional): Whether to normalize the amount of negative examples per each positive example of a class. Defaults to False.

- `positives` (int, optional): Number of positive examples to generate per source. Defaults to 1.

- `negatives` (int, optional): Number of negative examples to generate per source. Defaults to -1.

- `multi_label` (bool, optional): Whether each example has multiple labels or not. Defaults to False.

In [None]:
nli_train_dataset = NLIDataset.load_dataset(train_dataset, classes=classes)
nli_test_dataset = NLIDataset.load_dataset(test_dataset, classes=classes)

### Model initialization

In our case, we will use the `knowledgator/comprehend_it-base` model. It was trained on multiple natural language inference datasets and performs strongly in zero-shot text classification.

Moreover, the model can be used for multiple information extraction tasks in a zero-shot setting.
Possible use cases of the model:
* Text classification
* Reranking of search results
* Named-entity recognition
* Relation extraction
* Entity linking
* Question answering


In [None]:
model_path = 'knowledgator/comprehend_it-base'

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSequenceClassification.from_pretrained(model_path)

If you want to use a different loss or classification head, we recommend using the `LiqFitModel` class.

In [None]:
from liqfit.modeling import LiqFitModel
from liqfit.losses import FocalLoss

backbone_model = AutoModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-xsmall')

loss_func = FocalLoss()

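# Note: running this cell rebinds `model`, so the Trainer below will fine-tune
# this LiqFitModel instead of the model loaded in the previous cell.
model = LiqFitModel(backbone_model.config, backbone_model, loss_func=loss_func)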
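
For intuition on why a focal loss can help, the sketch below compares it with cross-entropy for a single example, using the standard formulation FL(p) = -(1 - p)^gamma * log(p), where p is the probability assigned to the correct class. It is a plain-Python illustration; the defaults and options of LiqFit's `FocalLoss` (for example its gamma, class weighting, or reduction) may differ, and `gamma=2.0` here is an assumption.

```python
import math


def cross_entropy(p_true):
    """Cross-entropy for one example, given the probability of the correct class."""
    return -math.log(p_true)


def focal_loss(p_true, gamma=2.0):
    """Focal loss: cross-entropy scaled down by (1 - p_true) ** gamma."""
    return ((1.0 - p_true) ** gamma) * cross_entropy(p_true)


# Easy examples (high p_true) are down-weighted far more than hard ones.
for p in (0.9, 0.5, 0.1):
    print(f"p={p}: CE={cross_entropy(p):.4f}  focal={focal_loss(p):.4f}")
```

Setting `gamma=0` recovers plain cross-entropy, so the parameter controls how strongly well-classified examples are discounted.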

### Training

We will use the **transformers** `Trainer` and collate data with the **LiqFit** `NLICollator`. Adjust the training parameters to suit your task.

In [None]:

data_collator = NLICollator(tokenizer, max_length=128, padding=True, truncation=True)


training_args = TrainingArguments(
    output_dir='comprehendo',
    learning_rate=3e-5,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    num_train_epochs=9,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_steps=5000,
    save_total_limit=3,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=nli_train_dataset,
    eval_dataset=nli_test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()
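
It can help to estimate how many optimizer steps a few-shot run takes, since `save_steps` only triggers a checkpoint once that many steps have elapsed. The sketch below is back-of-the-envelope arithmetic: the number of NLI pairs generated per source example (`pairs_per_example`) is an assumption, and it presumes a single device with no gradient accumulation.

```python
import math


def total_optimizer_steps(num_examples, batch_size, epochs):
    """Steps per epoch (rounded up) times the number of epochs."""
    return math.ceil(num_examples / batch_size) * epochs


source_examples = 6 * 8   # len(classes) * N from the sampling cell
pairs_per_example = 6     # ASSUMPTION: one hypothesis per class for every text
steps = total_optimizer_steps(source_examples * pairs_per_example,
                              batch_size=3, epochs=9)
print(steps, steps >= 5000)
```

Under these assumptions the run ends well before step 5000, so no intermediate checkpoint is written; calling `trainer.save_model()` after training is one way to keep the fine-tuned weights.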

### Testing

We will test the fine-tuned model in a few-shot setting with the `zero-shot-classification` pipeline of the **transformers** library. Then we will calculate basic classification metrics, such as *accuracy*, *precision*, *recall*, and *F1 score*, with **sklearn** `classification_report`.

In [None]:
from transformers import pipeline
from sklearn.metrics import classification_report
from tqdm import tqdm
import torch

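# Use the first GPU when available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

classifier = pipeline("zero-shot-classification",
                      model=model, tokenizer=tokenizer, device=device)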

In [None]:
label2idx = {label: idx for idx, label in enumerate(classes)}

# Should match the template used to build the NLI training data
template = 'This example is {}.'

preds = []
fallback_idx = 0  # arbitrary class assigned to empty texts

for example in tqdm(test_dataset):
    if not example['text']:
        # Skip empty texts, which the pipeline cannot classify
        preds.append(fallback_idx)
        continue
    # Pass the raw class names: the pipeline inserts them into hypothesis_template itself
    pred = classifier(example['text'], classes, hypothesis_template=template)['labels'][0]
    preds.append(label2idx[pred])

print(classification_report(test_dataset['label'], preds, target_names=classes, digits=4))
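
To see why taking `['labels'][0]` yields the predicted class: in single-label mode the pipeline scores each candidate label by its entailment logit, applies a softmax across the labels, and returns them sorted by score in descending order. The sketch below reproduces that ranking step in plain Python; the logit values are made up for illustration and `rank_labels` is a hypothetical helper.

```python
import math


def rank_labels(entailment_logits):
    """Softmax the entailment logits across labels and sort by descending score."""
    max_logit = max(entailment_logits.values())
    exps = {label: math.exp(logit - max_logit)  # subtract max for numerical stability
            for label, logit in entailment_logits.items()}
    total = sum(exps.values())
    scores = {label: value / total for label, value in exps.items()}
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Made-up entailment logits for three candidate labels.
toy_logits = {"joy": 2.1, "sadness": -0.3, "anger": 0.4}
ranking = rank_labels(toy_logits)
print(ranking[0][0])  # the top-ranked label plays the role of ['labels'][0]
```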