# Labeled inputs generator

This notebook generates the text files we use to train our conditional poem generators. It supports only English, because the only labeled dataset we have consists of English poems.

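At the end, two text files are generated in the current directory:
* A concatenation of all the poems of the training set, named `all_poems.labeled.by_[cat].train.[lang].txt`
* A concatenation of all the poems of the validation set, named `all_poems.labeled.by_[cat].valid.[lang].txt`

Depending on the labels writer chosen and on whether missing labels are estimated, a suffix (`_kv`, `_kv_mv`, `_exp`, `_exp_s` and/or `_filled`) is appended right after `[cat]`.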

Each poem is labeled by prepending one or more additional verses with its category labels (form, topic, or both).

Two additional files must be kept alongside them. We simply copy them to the current directory from the source data of an unconditional generator:
* A JSON file with the chosen formatting configuration, named `all_poems.[lang].conf.json`
* A CSV file with the assignment of poems to splits, named `all_poems.[lang].splits.csv`

We copy the splits and the format used to train an unconditional poem generator, which we had previously saved in our datasets repo, so that we have a comparable baseline.

You can still use a different file structure by defining your own `PoemsFileConfig` instance and passing it to `LabeledPoemsIOWriter.__init__`.

Set `run_as_standalone_nb = True` if you are running this notebook outside of a clone of its repository (https://github.com/Poems-AI/AI.git), for example in a Colab or Kaggle notebook.

In [None]:
run_as_standalone_nb = False


from pathlib import Path


if run_as_standalone_nb:
    import sys    
    root_lib_path = Path('AI').resolve()
    if not root_lib_path.exists():
        !git clone https://github.com/Poems-AI/AI.git
    if str(root_lib_path) not in sys.path:
        sys.path.insert(0, str(root_lib_path))
        
    !pip install -r {root_lib_path/'requirements.txt'}
else:
    import local_lib_import

In [None]:
import pandas as pd
from poemsai.data import (DataSource, get_ds_root_placeholder,  LabeledPoemsSplitsDfReader, 
                          LabeledPoemsIOWriter, LabelsEstimator, LabelsType, label_type_to_str, 
                          LabelsWriterStd, LabelsWriterKeyValue, LabelsWriterKeyValueMultiverse, 
                          LabelsWriterExplained, PoemsFileConfig, PoemsSplitsDfContentReader)

Clone our datasets repo:

In [None]:
!git clone https://github.com/Poems-AI/dataset.git

## Labels selection

Choose whether to label the poems by form (`LabelsType.Forms`), by topic (`LabelsType.Topics`), or by every category (`LabelsType.All`).

Note that the set of poems selected differs depending on the label type, unless you choose `LabelsType.All`.
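
The cells in this notebook select poems by looking for `/<category>/` in the `Location` column and take the name of the poem's parent folder as its label. Here is a minimal standalone sketch of that idea. The paths and the folder names `forms` and `topics` are made up for illustration, and `select_by_category` and `label_from_location` are hypothetical helpers, not part of `poemsai`.

```python
from pathlib import PurePosixPath

# Made-up locations; the folder names "forms" and "topics" are assumptions.
locations = [
    "root/forms/sonnet/poem_a.txt",
    "root/topics/love/poem_b.txt",
    "root/topics/beach/poem_c.txt",
]


def select_by_category(locations, category=None):
    """Keep the locations under `/<category>/`; keep everything if category is None."""
    if category is None:
        return list(locations)
    return [loc for loc in locations if f"/{category}/" in loc]


def label_from_location(location):
    """Use the name of the poem's parent folder as its label."""
    return PurePosixPath(location).parent.name


topic_poems = select_by_category(locations, "topics")
print([label_from_location(loc) for loc in topic_poems])
```

With a single category only the poems stored under that category's folder are kept, which is why the resulting training set changes with the label type.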

In [None]:
label_with = LabelsType.All

Set `fill_missing_labels = True` if you want the unknown labels to be estimated by classifiers. It only has an effect if `label_with` is `LabelsType.All`.

In [None]:
fill_missing_labels = False

## Text files generation

In [None]:
data_path = Path('dataset/all.txt/en.txt/only_end_tags/')
splits_df_path = data_path/'all_poems.en.splits.csv'
splits_df = pd.read_csv(splits_df_path, index_col=0)

If you are running outside of Kaggle, set `kaggle_ds_root` to the root folder that contains the poems dataset
by Kaggle user michaelarman (https://www.kaggle.com/michaelarman/poemsdataset).
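
The splits file stores each poem location with a placeholder instead of the actual dataset root, and the next cell swaps that placeholder for `kaggle_ds_root` or `own_ds_root`. A rough sketch of the substitution on a single path follows. The placeholder strings and the example path are hypothetical; the real placeholders come from `get_ds_root_placeholder`.

```python
# Hypothetical placeholder strings mapped to local root folders.
PLACEHOLDER_ROOTS = {
    "<kaggle_root>": "/kaggle/input",
    "<own_root>": "./dataset",
}


def resolve_location(location, placeholder_roots=PLACEHOLDER_ROOTS):
    """Replace every known placeholder in `location` with its local root folder."""
    for placeholder, root in placeholder_roots.items():
        location = location.replace(placeholder, root)
    return location


# The path below is made up for illustration.
print(resolve_location("<kaggle_root>/poemsdataset/forms/sonnet/poem.txt"))
```

Outside of Kaggle, only the value mapped to the Kaggle placeholder needs to change.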

In [None]:
kaggle_ds_root = '/kaggle/input'
own_ds_root = './dataset'
kaggle_ds_root_placeholder = get_ds_root_placeholder(DataSource.Kaggle)
own_ds_root_placeholder = get_ds_root_placeholder(DataSource.Marcos)


def filter_df(df, labels_type, split):
    # An empty filter string is contained in every location, so `LabelsType.All` keeps all the poems
    labels_type_filter = f'/{labels_type.value}/' if labels_type != LabelsType.All else ''
    kaggle_filter = kaggle_ds_root_placeholder if labels_type != LabelsType.All else ''
    mask = (
        df.Location.str.contains(labels_type_filter, regex=False)
        & df.Location.str.contains(kaggle_filter, regex=False)
        & (df.Split == split)
    )
    # Copy after filtering so that later edits don't modify `splits_df`
    return df[mask].copy()


def replace_location_placeholder(df):
    df.Location = df.Location.str.replace(kaggle_ds_root_placeholder, 
                                          kaggle_ds_root,
                                          regex=False)
    df.Location = df.Location.str.replace(own_ds_root_placeholder, 
                                          own_ds_root,
                                          regex=False)
    return df

    
kaggle_ds_train_split_df = replace_location_placeholder(filter_df(splits_df, label_with, 'Train'))
kaggle_ds_valid_split_df = replace_location_placeholder(filter_df(splits_df, label_with, 'Validation'))
kaggle_ds_train_split_df, kaggle_ds_valid_split_df

Choose the proper labels writer class depending on how you want the poems to be tagged with the labels:
- `LabelsWriterStd`: include a verse for every label, '?' if the label for a category isn't available. For instance:<br>
    form: ? \n<br>
    topic: beach \n<br>
    Verse 1 \n
- `LabelsWriterKeyValue`: include one verse with all the labels, with "key: value" format, '?' if the label for a category isn't available. For instance:<br>
    form: sonnet, topic: love \n<br>
    Verse 1 \n
- `LabelsWriterKeyValueMultiverse`: include a verse for every label, with "key: value" format, '?' if the label for a category isn't available. For instance:<br>
    form: sonnet \n<br>
    topic: love \n<br>
    Verse 1 \n
- `LabelsWriterExplained`: include a verse with a description of the labels, '?' if the label for a category isn't available. For instance:<br>
    This is a poem with sonnet form about love: \n<br>
    Verse 1 \n
- `LabelsWriterExplained(omit_empty=True)`: include a verse with a description of the labels, leaving out the categories that aren't available. For instance (assume form is not available):<br>
    This is a poem about love: \n<br>
    Verse 1 \n
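
To make the formats above concrete, here is a minimal sketch that reproduces the example outputs with plain functions. It is not the `poemsai` implementation: the function names are hypothetical, and the way missing labels are handled only follows the descriptions in the list.

```python
def fmt_key_value(labels):
    """One verse with every label as "key: value"; '?' marks a missing label."""
    return [", ".join(f"{key}: {value or '?'}" for key, value in labels.items())]


def fmt_key_value_multiverse(labels):
    """One verse per label as "key: value"; '?' marks a missing label."""
    return [f"{key}: {value or '?'}" for key, value in labels.items()]


def fmt_explained(labels, omit_empty=False):
    """One verse describing the labels; missing ones become '?' or are left out."""
    form = labels.get("form") or ""
    topic = labels.get("topic") or ""
    if not omit_empty:
        form, topic = form or "?", topic or "?"
    text = "This is a poem"
    if form:
        text += f" with {form} form"
    if topic:
        text += f" about {topic}"
    return [text + ":"]


labels = {"form": "sonnet", "topic": "love"}
poem = ["Verse 1", "Verse 2"]

# The label verses are prepended to the verses of the poem.
for fmt in (fmt_key_value, fmt_key_value_multiverse, fmt_explained):
    print("\n".join(fmt(labels) + poem), end="\n\n")
```

Putting all the labels in one verse keeps the prompt short, while one verse per label lets you condition generation on a subset of the categories by writing only the first label verses.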

In [None]:
labels_writer = LabelsWriterExplained()

In [None]:
file_conf_path = data_path/'all_poems.en.conf.json'
labels_type_str = label_type_to_str(label_with)


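def get_labels_writer_desc(labels_writer):
    # Suffix that identifies the labels writer in the output file name
    if isinstance(labels_writer, LabelsWriterStd):
        return ''
    if isinstance(labels_writer, LabelsWriterKeyValue):
        return '_kv'
    if isinstance(labels_writer, LabelsWriterKeyValueMultiverse):
        return '_kv_mv'
    if isinstance(labels_writer, LabelsWriterExplained):
        return '_exp' if not labels_writer.omit_empty else '_exp_s'
    # Fail early instead of silently writing 'None' into the file name
    raise ValueError(f'Unsupported labels writer: {type(labels_writer).__name__}')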


labels_estimator = LabelsEstimator('gpt2', 'YOUR_HF_USER', 'YOUR_HF_ACCESS_TOKEN') if fill_missing_labels else None
poem_content_reader = PoemsSplitsDfContentReader()


def label_func_multi(location:str):
    labels = dict()
    for cat in LabelsType:
        if cat == LabelsType.All: continue
        cat_str = label_type_to_str(cat)
        if f'/{cat.value}/' in location:
            labels[cat_str] = Path(location).parent.name  
        elif fill_missing_labels:
            poem_lines = poem_content_reader.extract_poem_lines(location)
            labels[cat_str] = labels_estimator.predict(cat, poem_lines) if len(poem_lines) > 0 else ''
        else:
            labels[cat_str] = ''
    return labels


label_func = label_func_multi if label_with == LabelsType.All else None

readers = [
    LabeledPoemsSplitsDfReader(df, label_func=label_func) 
    for df in (kaggle_ds_train_split_df, kaggle_ds_valid_split_df)
]
split_names = ['train', 'valid']
lw_desc = get_labels_writer_desc(labels_writer)
fill_flag = '_filled' if fill_missing_labels else ''

for split_name, reader in zip(split_names, readers):
    labeled_poems_file_path = f'./all_poems.labeled.by_{labels_type_str}{lw_desc}{fill_flag}.{split_name}.en.txt'
    with open(labeled_poems_file_path, 'w', encoding='utf-8') as out_file:
        writer = LabeledPoemsIOWriter(
            out_file, 
            PoemsFileConfig.from_json(file_conf_path),
            labels_writer=labels_writer,
        )
        for labeled_poem in reader:
            writer.write_poem(labeled_poem)

!cp $splits_df_path .
!cp $file_conf_path .