In [None]:
import os

import datasets
import torch
from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPTNeoXForCausalLM,
    Trainer,
    TrainingArguments,
)

In [None]:
# Dataset saved with `save_to_disk` (presumably pre-tokenized for pythia-160m, per the
# path). The default is the shared cluster location; set CHEMRXIV_DATASET_DIR to load
# a copy from elsewhere without editing the notebook.
DATASET_DIR = os.environ.get(
    "CHEMRXIV_DATASET_DIR",
    "/fsx/proj-chemnlp/data/EleutherAI/pythia-160m/marianna13/chemrxiv",
)
dataset = datasets.load_from_disk(DATASET_DIR)

In [None]:
# Pythia-160m tokenizer, pinned to the `main` revision (same checkpoint as the model).
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m", revision="main")
# Register a padding token so the LM collator can pad batches. The cell shows how many
# tokens were newly added to the vocabulary.
num_added_tokens = tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
num_added_tokens

In [None]:
# Causal-LM collation (mlm=False): pads each batch and builds `labels` from `input_ids`,
# which is what makes `out['loss']` available in the forward checks.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Tiny loader, used only to pull one sample batch for the sanity checks.
dl = DataLoader(dataset=dataset, batch_size=2, collate_fn=data_collator)

In [None]:
# Base causal LM: same checkpoint and revision as the tokenizer above.
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-160m", revision="main")

In [None]:
# Prompt tuning: learn `num_virtual_tokens` soft-prompt embeddings for a causal LM,
# initialised from the embedding(s) of `prompt_tuning_init_text`.
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=10,
    prompt_tuning_init=PromptTuningInit.TEXT,
    # NOTE(review): PEFT tokenizes this text and (in recent versions) repeats it to fill
    # num_virtual_tokens, so a single space likely starts all 10 virtual tokens from the
    # same embedding. Confirm, and consider a descriptive task prompt instead.
    prompt_tuning_init_text=" ",
    tokenizer_name_or_path="EleutherAI/pythia-160m",
)
# Wrap the base model with the prompt-tuning adapter.
peft_model = get_peft_model(model, peft_config)

In [None]:
# Grab the first collated batch for the forward-pass sanity checks.
# next(iter(...)) replaces the for/break idiom and fails loudly (StopIteration) on an
# empty loader instead of silently leaving `b` undefined.
b = next(iter(dl))
# GPTNeoXForCausalLM.forward() has no `token_type_ids` parameter, so strip it before
# calling `model(**b)`. The default avoids a KeyError if the column is absent.
token_types = b.pop("token_type_ids", None)

In [None]:
# Baseline: loss of the underlying base model on the sample batch (calls `model`
# directly, so presumably no virtual prompt tokens are involved).
# A sanity check needs no gradients, so skip building an autograd graph.
with torch.no_grad():
    out = model(**b)
out["loss"]

In [None]:
# Same batch through the prompt-tuned wrapper, for comparison with the baseline loss.
# no_grad: otherwise a graph through the trainable prompt embeddings is built and kept
# alive by the global `out` for no benefit.
with torch.no_grad():
    out = peft_model(**b)
out["loss"]

In [None]:
# Freezing check: does *every* parameter still require grad, for the base model and for
# the PEFT wrapper respectively? Generator expressions let all() short-circuit on the
# first frozen parameter instead of materialising full lists.
(
    all(p.requires_grad for p in model.parameters()),
    all(p.requires_grad for p in peft_model.parameters()),
)

In [None]:
# Display the base model's module tree (bare last expression, rendered as the cell output).
model

In [None]:
# Hold out 2.5% of the corpus for evaluation. train_test_split shuffles by default, so
# pin a seed to keep the train/eval partition identical across kernel restarts.
SPLIT_SEED = 42
split_dataset = dataset.train_test_split(test_size=0.025, seed=SPLIT_SEED)

In [None]:
# Hyperparameters for the prompt-tuning run, grouped by concern.
training_args = TrainingArguments(
    # Output: checkpoints and trainer state go to the current working directory;
    # external experiment trackers are disabled.
    output_dir='./',
    report_to="none",
    # Optimisation schedule.
    num_train_epochs=1,
    learning_rate=3e-4,
    per_device_train_batch_size=28,
    per_device_eval_batch_size=28,
    # Numeric precision: bf16 on, fp16 explicitly off.
    bf16=True,
    fp16=False,
    # Logging / evaluation cadence, measured in optimizer steps.
    # NOTE(review): newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy`. Keep this in sync with the installed version.
    logging_steps=1,
    evaluation_strategy='steps',
    eval_steps=50,
    # Data loading.
    dataloader_num_workers=4,
)

In [None]:
# Wire the PEFT-wrapped model, the training config, the LM collator and the
# train / held-out splits into a Trainer.
trainer = Trainer(
    model=peft_model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
)

In [None]:
# Launch prompt tuning. Evaluation on the held-out split follows the step-based schedule
# configured in `training_args`.
trainer.train()