In [None]:
from datasets import load_dataset
from transformers import set_seed, AutoModelForSeq2SeqLM, AutoTokenizer
from peft import get_peft_model, MultitaskPromptTuningConfig, TaskType, MultitaskPromptTuningInit

# Fix the random seed first so everything below that draws random numbers
# (e.g. adapter parameter initialisation) is repeatable across runs.
set_seed(42)

# Checkpoint identifier shared by the tokenizer, the base model and the
# PEFT configuration.
model_name = "google/flan-t5-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Multitask prompt tuning setup: 50 virtual tokens, initialised from the
# text below, for two tasks.
# NOTE(review): the init text suggests the two tasks are sentiment
# classification and NLI -- confirm against the dataset cells.
peft_config = MultitaskPromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_tasks=2,
    num_virtual_tokens=50,
    num_transformer_submodules=1,
    prompt_tuning_init=MultitaskPromptTuningInit.TEXT,
    prompt_tuning_init_text="classify the following into either positive or negative, or entailment, neutral or contradiction:",
    tokenizer_name_or_path=model_name,
)

# Load the base seq2seq model, wrap it with the PEFT adapter, and move the
# result onto the GPU. `.cuda()` requires a CUDA device to be available.
model = get_peft_model(
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
    peft_config,
).cuda()