# Evaluation of conditional generators

This notebook serves as a tool to evaluate the conditioning quality of a text generator by means of a text classifier.

It assumes the generator is able to use the first word of an input sequence as a label that conditions the characteristics of the output text.

We build the prompts fed to the generator in two different ways and compute the validation loss and accuracy for each of them, which results in four different metrics.
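The two metrics computed for each procedure can be sketched as follows. This is a minimal pure-Python illustration, not the code of `ConditionalGenEvaluator`. It assumes that the loss is the mean cross-entropy of the classifier scores with respect to the label used in the prompt, and that the accuracy is the fraction of generated texts whose predicted label matches that prompt label. The helper `loss_and_accuracy` and the example scores are hypothetical.

```python
import math
from typing import Sequence, Tuple


def loss_and_accuracy(logits: Sequence[Sequence[float]],
                      targets: Sequence[int]) -> Tuple[float, float]:
    """Mean cross-entropy loss and accuracy of classifier logits (hypothetical helper).

    `logits[i]` holds the classifier scores for generated text i and
    `targets[i]` is the index of the label that was used in its prompt.
    """
    total_loss = 0.0
    hits = 0
    for row, target in zip(logits, targets):
        # log-sum-exp, shifted by the max score for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        # Cross-entropy = -log softmax(row)[target]
        total_loss += log_z - row[target]
        # Predicted label = argmax of the scores (the first one on ties)
        pred = max(range(len(row)), key=lambda j: row[j])
        hits += int(pred == target)
    n = len(targets)
    return total_loss / n, hits / n


loss, acc = loss_and_accuracy([[2.0, 0.0], [0.0, 0.0]], [0, 1])
print(f"loss={loss:.4f}, accuracy={acc:.2f}")
```

In the notebook itself, the scores would come from `clf_model`, and the target indices from the position of each prompt label among the classifier's classes.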

**Procedure A**

1. Generate text using as prompt just a label followed by a line break (one label per input sequence, as many sequences as labels).
2. Delete the labels from the generated text.
3. Feed the generated text to the corresponding classifier to evaluate the metrics.
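Steps 1 and 2 of Procedure A can be sketched as below. This assumes that the labels are written as a single word followed by a line break, as described above. The real prompt format is whatever the chosen `BaseLabelsWriter` subclass produces. The helpers `build_label_prompts` and `remove_label`, and the labels `"love"` and `"nature"`, are hypothetical.

```python
from typing import List


def build_label_prompts(labels: List[str]) -> List[str]:
    """Step 1 (hypothetical helper): one prompt per label, the label plus a line break."""
    return [f"{label}\n" for label in labels]


def remove_label(generated: str, label: str) -> str:
    """Step 2 (hypothetical helper): drop the first line if it is exactly the prompt label."""
    first_line, sep, rest = generated.partition("\n")
    if sep and first_line.strip() == label:
        return rest
    # Leave the text untouched if it does not start with the expected label line
    return generated


prompts = build_label_prompts(["love", "nature"])
print(prompts)
print(remove_label("love\nRoses are red,\nviolets are blue", "love"))
```

Matching the whole first line, rather than a raw string prefix, avoids cutting a word such as "lovely" when the label is "love".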

**Procedure B**

1. Generate text using as prompt an initial substring of each sequence in the validation set, prepended with its label and a line break. The length of the prompt is limited to the minimum of 100 characters and 1/4 of the length of the text.
2. Delete the labels from the generated text.
3. Feed the generated text to the corresponding classifier to evaluate the metrics.
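The prompt of step 1 of Procedure B can be sketched as follows. The sketch assumes three things: the length limit applies to the text fragment only (not to the prepended label), 1/4 of the length is rounded down, and the label format is the label plus a line break. The helper `build_fragment_prompt` is hypothetical.

```python
def build_fragment_prompt(label: str, text: str, max_chars: int = 100) -> str:
    """Step 1 (hypothetical helper): label + line break + initial substring of a text."""
    # Fragment length = min(100 characters, 1/4 of the text length);
    # rounding down with integer division is an assumption.
    prompt_len = min(max_chars, len(text) // 4)
    return f"{label}\n{text[:prompt_len]}"


short_text = "The woods are lovely, dark and deep"
print(repr(build_fragment_prompt("nature", short_text)))
```

For short texts the 1/4 bound dominates. Long texts are always capped at 100 characters, so the generator still has to produce most of the poem on its own.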

The dataset split used here is the same one we use as the validation set when training the text classifier (see [train_poems_classifier.ipynb](train_poems_classifier.ipynb)).


Set `run_as_standalone_nb = True` if you are running this notebook outside of a clone of its repository (https://github.com/Poems-AI/AI.git). For example, in a Colab or Kaggle notebook.

In [None]:
!pip install -r ../requirements.txt

In [None]:
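run_as_standalone_nb = False


from pathlib import Path


if run_as_standalone_nb:
    import sys
    # Clone the repo (if needed) and make its `poemsai` package importable
    root_lib_path = Path('AI').resolve()
    if not root_lib_path.exists():
        !git clone https://github.com/Poems-AI/AI.git
    if str(root_lib_path) not in sys.path:
        sys.path.insert(0, str(root_lib_path))

    # Extra dependencies needed in standalone mode
    !pip install happytransformer
    !pip install transformers
    !pip install "torch>=1.10"
    # git-lfs is required to clone repositories that store large files (e.g. model checkpoints)
    !apt-get install git-lfs
    !git lfs install
else:
    # Helper module of the repo that makes the local `poemsai` package importable
    import local_lib_import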

In [None]:
import os
import pandas as pd
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
import torch
import torch.nn.functional as F

In [None]:
from poemsai.config import set_config_value
from poemsai.data import (build_labeled_dfs_from_splits, label_type_to_str, LabelsType, LabelsWriterStd, 
                          PoemsFileConfig)
from poemsai.metrics import ConditionalGenEvaluator
from poemsai.nb_utils import download_checkpoint_from_hf_hub

Clone our datasets repo:

In [None]:
!git clone https://github.com/Poems-AI/dataset.git

Choose the type of labels the generator must be conditioned on:

In [None]:
labels_type = LabelsType.Topics

In [None]:
HF_USER = 'YOUR_HF_USER'
gen_model_name = 'gpt2-poems-endtags.en'
cat_name = label_type_to_str(labels_type)
# Classifier trained to predict labels of the chosen type
clf_model_name = f'distilbert-poems-clf-by-{cat_name}'
hf_pwd = 'YOUR_HF_PASSWORD'
download_checkpoint_from_hf_hub(gen_model_name, HF_USER, hf_pwd)
download_checkpoint_from_hf_hub(clf_model_name, HF_USER, hf_pwd)
# Clear the password once the checkpoints have been downloaded
hf_pwd = None
gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name)
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name)
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)

You must use the same file config and the same `BaseLabelsWriter` subclass that were used to generate the text file on which `gen_model_name` was trained.

This way, the prompts fed to the generator will have the same format as the training dataset.

In [None]:
file_conf = PoemsFileConfig.from_json('dataset/all.txt/en.txt/only_end_tags/all_poems.en.conf.json')
all_cats_ordered = [label_type_to_str(cat) for cat in LabelsType if cat != LabelsType.All]
evaluator = ConditionalGenEvaluator(gen_model, gen_tokenizer, clf_model, clf_tokenizer, file_conf,
                                    cat_name, all_cats_ordered, labels_writer=LabelsWriterStd())

## Metrics A

These metrics are referred to as "*[label_type] conditional loss A*" and "*[label_type] conditional accuracy A*" in the [results doc](../docs/results.md), with "[label_type]" being one of {"topic", "form"}.

In [None]:
evaluator.eval_with_labels_as_prompt()

## Metrics B

These metrics are referred to as "*[label_type] conditional loss B*" and "*[label_type] conditional accuracy B*" in the [results doc](../docs/results.md), with "[label_type]" being one of {"topic", "form"}.

If you are running outside Kaggle, point `KAGGLE_DS_ROOT` to the root folder that contains the poems dataset by Kaggle user michaelarman (https://www.kaggle.com/michaelarman/poemsdataset).

In [None]:
set_config_value('KAGGLE_DS_ROOT', '/kaggle/input')

In [None]:
splits_df_path = 'dataset/all.txt/en.txt/only_end_tags/all_poems.en.splits.csv'
splits_df = pd.read_csv(splits_df_path, index_col=0)
# Only the validation split is needed: the same one used to validate the classifier
_, valid_df = build_labeled_dfs_from_splits(splits_df, labels_type)
evaluator.eval_with_seq_fragment_as_prompt(valid_df)