In [None]:
# Install ruprompts together with its `hydra` extra into the kernel's environment.
# NOTE(review): the extra is presumably what the `ruprompts-train` command in the
# Training section needs — confirm; consider pinning a version for reproducibility.
%pip install ruprompts[hydra]

In [None]:
# Download the RUSSE Detox 2022 train/dev splits.
# `-O` pins the output file name so that re-running this cell overwrites the files in
# place; without it wget keeps the existing copy and stores the new download as
# `train.tsv.1` / `dev.tsv.1`, leaving a stale `train.tsv` behind.
!wget -O train.tsv https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/train.tsv
!wget -O dev.tsv https://raw.githubusercontent.com/skoltech-nlp/russe_detox_2022/main/data/input/dev.tsv

In [None]:
import pandas as pd

# Strip the `index` column from the training split and write the file back in place.
# `errors="ignore"` keeps this cell re-runnable: after the first run the saved file no
# longer contains an `index` column, and a plain drop would raise KeyError.
df = pd.read_csv("train.tsv", sep="\t").drop(columns=["index"], errors="ignore")
df.to_csv("train.tsv", index=False, sep="\t")

## Training

In [None]:
# Launch prompt training through the `ruprompts-train` CLI with only CUDA device 0 visible.
# The key=value arguments are configuration overrides:
#   * prompt_format.template places the `toxic_comment` field between `<P*100>` and `<P*20>`
#     (presumably 100 + 20 trainable prompt tokens, matching "length-120" in the run name — TODO confirm);
#   * dataset.data_files.* point at the local ./train.tsv and ./dev.tsv files;
#   * preprocessing.target_field makes `neutral_comment1` the target, and
#     preprocessing.truncation_field names `toxic_comment` as the field to truncate;
#   * training.max_steps=100000 matches the `./checkpoint-100000` directory loaded in the Inference section
#     (hydra.run.dir="." presumably makes the run write into the current directory — TODO confirm).
!CUDA_VISIBLE_DEVICES=0 ruprompts-train \
    task=text2text \
    training.run_name=detox-russe-tensor-linear-lr-1e-1-bs-2-length-120 \
    task.task_name=detox-russe \
    prompt_provider=tensor \
    prompt_format.template='"<P*100>{toxic_comment}<P*20>"' \
    training.learning_rate=1e-1 \
    training.per_device_train_batch_size=2 \
    training.max_steps=100000 \
    scheduler=linear_schedule_with_warmup \
    +dataset=from_tsv \
    dataset.data_files.train=./train.tsv \
    dataset.data_files.validation=./dev.tsv \
    preprocessing.target_field=neutral_comment1 \
    preprocessing.truncation_field=toxic_comment \
    preprocessing.max_tokens=1792 \
    callbacks=[freeze_transformer_unfreeze_prompt,reduce_checkpoint,save_pretrained_prompt] \
    training.report_to=tensorboard \
    hydra.run.dir="."

## Inference

In [None]:
from ruprompts import Prompt
from transformers import pipeline, GPT2LMHeadModel, AutoTokenizer

# Backbone checkpoint shared by the language model and its tokenizer.
backbone_id = "ai-forever/rugpt3large_based_on_gpt2"

# Load the backbone pieces first, then the prompt saved by the training run.
tokenizer = AutoTokenizer.from_pretrained(backbone_id)
model = GPT2LMHeadModel.from_pretrained(backbone_id)
prompt = Prompt.from_pretrained("./checkpoint-100000")

# Assemble the prompt-aware text2text pipeline on CUDA device 0.
ppln = pipeline(
    "text2text-generation-with-prompt",
    prompt=prompt,
    model=model,
    tokenizer=tokenizer,
    device=0,
)

In [None]:
from datasets import load_dataset

# Read both tab-separated splits with the generic "csv" loader and keep a handle
# on the validation split, which is the one used for generating predictions.
data_files = {"train": "train.tsv", "validation": "dev.tsv"}
dataset_dict = load_dataset("csv", data_files=data_files, sep="\t")
valid_dataset = dataset_dict["validation"]

In [None]:
from tqdm import tqdm
import transformers
# Only show transformers log messages at error level during generation.
transformers.logging.set_verbosity_error()

# Number of beams; the same number of candidate sequences is returned per input.
beam_count = 10

predictions = []

# For every toxic comment, generate `beam_count` beam-search candidates and keep the longest.
for toxic_comment in tqdm(valid_dataset["toxic_comment"]):
    options = ppln(
        {"toxic_comment": toxic_comment},
        do_sample=False,
        num_beams=beam_count,
        num_return_sequences=beam_count,
    )

    # Strip literal "<pad>" markers from each generated candidate.
    candidates = [option["generated_text"].replace("<pad>", "") for option in options]
    # Longest candidate wins; on equal length the later candidate is kept
    # (same tie-break as a stable sort by length followed by [-1]).
    answer = max(reversed(candidates), key=len)
    predictions.append(answer)

# One prediction per line: embedded newlines are flattened to spaces so that line k of
# the file corresponds to validation example k. UTF-8 is set explicitly because the
# comments are Russian text and the platform default encoding may not be able to encode it.
with open("subm.txt", "w", encoding="utf-8") as f:
    f.writelines(prediction.replace("\n", " ") + "\n" for prediction in predictions)