In [None]:
%pip install ruprompts[hydra]

In this tutorial we will train a prompt for a style transfer task using the `ruprompts-train` entrypoint. As an example task we will take the [RUSSE 2022 detoxification competition](https://github.com/skoltech-nlp/russe_detox_2022). Let's download the data:

In [None]:
!wget https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/train.tsv
!wget https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/dev.tsv
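
Before training, it can help to glance at a few rows of the downloaded files. The helper below is a minimal sketch using only the standard library; `preview_tsv` is a hypothetical name, and the column names `toxic_comment` and `neutral_comment1` are assumed from the fields referenced later in this tutorial.

```python
import csv
from itertools import islice
from pathlib import Path


def preview_tsv(path, n=3):
    """Return the first n rows of a tab-separated file as dictionaries."""
    with Path(path).open(encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f, delimiter="\t")
        return list(islice(reader, n))


# Runs only if train.tsv has already been downloaded to the working directory.
if Path("train.tsv").exists():
    for row in preview_tsv("train.tsv"):
        # Column names are assumptions; .get() returns None if they differ.
        print(row.get("toxic_comment"), "->", row.get("neutral_comment1"))
```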

The `ruprompts-train` entrypoint is a Hydra application. See the [corresponding section in the docs](https://ai-forever.github.io/ru-prompts/hydra/) for details, and the section on [config structure](https://ai-forever.github.io/ru-prompts/hydra/config/) for the available group options and parameters.

Here is a brief explanation of the parameters selected in the cell below:
  - `task=text2text` - selects the default model and preprocessing
  - `training.run_name=detox-russe` - the run name to be used if WandB is installed and enabled
  - `task.task_name=detox` - the task name used to group task workdirs
  - `prompt_provider=tensor` - selects `TensorPromptProvider` as the prompt provider
  - `prompt_format.template='"<P*60>{toxic_comment}<P*20>"'` - our prompt will contain a prefix of 60 tokens before `{toxic_comment}` and an infix of 20 tokens after it
  - `training.learning_rate=0.1`, `scheduler=linear_schedule_with_warmup` - optimization parameters
  - `+dataset=from_tsv` - selects an option for loading the dataset from local TSV files
  - `dataset.data_files.train=/content/train.tsv` - defines the training file
  - `dataset.data_files.validation=/content/dev.tsv` - defines the validation file
  - `preprocessing.target_field=neutral_comment1` - selects the dataset field to be used as target
  - `preprocessing.max_tokens=1000` - defines the maximum length of training sequences in tokens, including the target sequence
  - `preprocessing.truncation_field=toxic_comment` - defines the field to be truncated if the training sequence exceeds `max_tokens`
  - `training.report_to=tensorboard` - reports logs locally
  - callbacks:
    - `freeze_transformer_unfreeze_prompt` - makes only the prompt provider's parameters trainable
    - `reduce_checkpoint` - reduces the checkpoint size after each save by keeping only the prompt provider's weights
    - `save_pretrained_prompt` - saves the prompt as pretrained in each checkpoint folder
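
To make the template parameter concrete, the sketch below splits a template string into its placeholder runs and dataset fields and counts the prompt tokens. It is an illustration only, not ruprompts' own parser: `describe_template` is a hypothetical helper, and it handles just the `<P*N>` and `{field}` forms used in this tutorial.

```python
import re


def describe_template(template):
    """Split a template into ("trainable", count) and ("field", name) segments."""
    segments = []
    # Match either a "<P*N>" run of prompt tokens or a "{field}" placeholder.
    for match in re.finditer(r"<P\*(\d+)>|\{(\w+)\}", template):
        count, field = match.groups()
        if count is not None:
            segments.append(("trainable", int(count)))
        else:
            segments.append(("field", field))
    return segments


segments = describe_template("<P*60>{toxic_comment}<P*20>")
total = sum(value for kind, value in segments if kind == "trainable")
print(segments)  # [('trainable', 60), ('field', 'toxic_comment'), ('trainable', 20)]
print(total)     # 80
```

With this template, 80 prompt tokens are added to every sequence, which is worth keeping in mind when choosing `preprocessing.max_tokens`.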

In [None]:
!ruprompts-train \
    task=text2text \
    training.run_name=detox-russe \
    task.task_name=detox \
    prompt_provider=tensor \
    prompt_format.template='"<P*60>{toxic_comment}<P*20>"' \
    training.learning_rate=0.1 \
    scheduler=linear_schedule_with_warmup \
    +dataset=from_tsv \
    dataset.data_files.train=/content/train.tsv \
    dataset.data_files.validation=/content/dev.tsv \
    preprocessing.target_field=neutral_comment1 \
    preprocessing.truncation_field=toxic_comment \
    preprocessing.max_tokens=1000 \
    callbacks=[freeze_transformer_unfreeze_prompt,reduce_checkpoint,save_pretrained_prompt] \
    training.report_to=tensorboard
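
When many overrides are involved, it can be convenient to assemble the command in Python rather than typing it by hand. The sketch below builds the same argument list from a dictionary; `build_command` is a hypothetical helper, and launching via `subprocess` is an assumption about how you might run it, not part of ruprompts. Note that without a shell only the inner double quotes of the template are needed.

```python
import shlex

overrides = {
    "task": "text2text",
    "training.run_name": "detox-russe",
    "task.task_name": "detox",
    "prompt_provider": "tensor",
    # The double quotes are part of the value passed to Hydra.
    "prompt_format.template": '"<P*60>{toxic_comment}<P*20>"',
    "training.learning_rate": "0.1",
    "scheduler": "linear_schedule_with_warmup",
    "+dataset": "from_tsv",
    "dataset.data_files.train": "/content/train.tsv",
    "dataset.data_files.validation": "/content/dev.tsv",
    "preprocessing.target_field": "neutral_comment1",
    "preprocessing.truncation_field": "toxic_comment",
    "preprocessing.max_tokens": "1000",
    "callbacks": "[freeze_transformer_unfreeze_prompt,reduce_checkpoint,save_pretrained_prompt]",
    "training.report_to": "tensorboard",
}


def build_command(entrypoint, overrides):
    """Turn a dict of Hydra overrides into an argument list."""
    return [entrypoint] + [f"{key}={value}" for key, value in overrides.items()]


command = build_command("ruprompts-train", overrides)
print(shlex.join(command))  # shell-quoted form, handy for copy-pasting

# To launch training from Python (requires ruprompts to be installed):
# import subprocess
# subprocess.run(command, check=True)
```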

If the `save_pretrained_prompt` callback is selected, you can load the prompt from any checkpoint, for example:

In [None]:
from ruprompts import Prompt

prompt = Prompt.from_pretrained("./outputs/debug/detox/20211231_235959/checkpoint-1500")

The prompt can then be used for inference just like one loaded from the HF Hub.
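
If you do not want to hard-code the step number, the most recent checkpoint can be located automatically. This is a minimal sketch, assuming checkpoints are written as `checkpoint-<step>` subdirectories of the run directory, as in the path above; `latest_checkpoint` is a hypothetical helper, not part of ruprompts.

```python
import re
from pathlib import Path


def latest_checkpoint(run_dir):
    """Return the checkpoint-<step> subdirectory with the highest step, or None."""
    best_step, best_path = -1, None
    for path in Path(run_dir).glob("checkpoint-*"):
        match = re.fullmatch(r"checkpoint-(\d+)", path.name)
        # Skip files and directories whose suffix is not a step number.
        if path.is_dir() and match and int(match.group(1)) > best_step:
            best_step, best_path = int(match.group(1)), path
    return best_path


# Example usage with the run directory from the cell above:
# checkpoint = latest_checkpoint("./outputs/debug/detox/20211231_235959")
# prompt = Prompt.from_pretrained(str(checkpoint))
```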

To upload the trained prompt to the Hub, use the `push_to_hub` method:

In [None]:
prompt.push_to_hub("prompt_backbone_task")