In [None]:
import inspect

from datasets import load_dataset
from peft import (
    PeftModel,
    PeftType,
    PromptTuningConfig,
    PromptTuningInit,
    TaskType,
    get_peft_config,
    get_peft_model,
)
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments


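## TODO

- Add `transformers` to the project dependencies.
- Add `torch` to the project dependencies.
- Check the padding setup (pad token and how the collator pads batches).
- Replace the deprecated `evaluation_strategy` training argument (or pass it in via config).
- `pydantic`: add it as a dependency and/or use it for the training config (TODO: clarify which).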

In [None]:
# Intended maximum sequence length, in tokens.
# NOTE(review): no active code in this notebook uses this constant. Also, the prompt puts the
# question and answer *after* the passage, so right-side truncation to 64 tokens would likely
# cut them off — confirm the truncation side before wiring this in.
MAX_LENGTH = 64

In [None]:
# Single source of truth for the checkpoint so the model and tokenizer can never be loaded
# from different repos or revisions.
CHECKPOINT_NAME = "EleutherAI/pythia-70m-deduped"
CHECKPOINT_REVISION = "step3000"  # Hub revision (branch) to load

model = GPTNeoXForCausalLM.from_pretrained(CHECKPOINT_NAME, revision=CHECKPOINT_REVISION)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_NAME, revision=CHECKPOINT_REVISION)

# Register a pad token so DataCollatorForLanguageModeling can pad batches.
tokenizer.add_special_tokens({"pad_token": "<|padding|>"})

# If the pad token was new to the vocabulary and its id falls outside the embedding table,
# grow the table; otherwise the lookup would go out of range. Compare against the real table
# size (not just "tokens added") so an oversized embedding table is never shrunk.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

In [None]:
# NOTE(review): scratch inspection cell — it only displays the `PromptTuningInit.TEXT` enum
# member used by `peft_config` below. Nothing assigns or reads its result, so it can likely
# be deleted.
PromptTuningInit.TEXT

In [None]:
# Prompt-tuning setup for a causal LM: 4 virtual tokens whose initial values come from the
# init text (TEXT init mode).
# NOTE(review): the init text is a single space, so presumably every virtual token starts from
# the same embedding — consider a task-related init text. The tokenizer is also loaded here
# without the revision used for the base model — confirm that is intended.
prompt_tuning_options = {
    "task_type": TaskType.CAUSAL_LM,
    "num_virtual_tokens": 4,
    "prompt_tuning_init": PromptTuningInit.TEXT,
    "prompt_tuning_init_text": " ",
    "tokenizer_name_or_path": "EleutherAI/pythia-70m-deduped",
}
peft_config = PromptTuningConfig(**prompt_tuning_options)


In [None]:
# Wrap the base model with the prompt-tuning adapter described by `peft_config`.
# Guarded so that re-running this cell does not wrap an already-wrapped PeftModel a
# second time (the un-guarded `model = get_peft_model(model, ...)` is not idempotent).
if not isinstance(model, PeftModel):
    model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


In [None]:
# BoolQ from the Hugging Face Hub; later cells use its "train" and "validation" splits.
dataset = load_dataset(path="boolq")


In [None]:
def tokenize_function(example, tokenizer):
    """Render one example as a Passage/Question/Answer prompt and tokenize it.

    Args:
        example: one dataset row; must provide "passage", "question" and "answer".
        tokenizer: callable applied to the rendered prompt string.

    Returns:
        Whatever ``tokenizer(prompt)`` returns, unchanged (for a Hugging Face tokenizer,
        an encoding indexable by "input_ids").
    """
    passage, question, answer = (example[key] for key in ("passage", "question", "answer"))
    # The answer is part of the text, so the causal-LM objective also trains on it.
    prompt = f"Passage:\n{passage} \nQuestion:\n{question}\nAnswer:\n{answer}"
    return tokenizer(prompt)


In [None]:
# Sanity check: push the first training example through the prompt template and tokenizer,
# then decode it back to eyeball the exact text the model will see.
first_train_ids = tokenize_function(dataset["train"][0], tokenizer)["input_ids"]
print(tokenizer.decode(first_train_ids))


In [None]:
# Number of examples kept per split; a small subset keeps this experiment fast.
N_SAMPLES = 100

# Tokenize every split and drop *all* raw columns (derived from the data rather than
# hard-coded), so only tokenizer outputs reach the collator.
tokenized = dataset.map(
    tokenize_function,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=dataset["train"].column_names,
)

# First N_SAMPLES rows of each split, bounded by the split size so a shorter split does
# not make `select` fail on out-of-range indices.
train_tokenized = tokenized["train"].select(range(min(N_SAMPLES, len(tokenized["train"]))))
val_tokenized = tokenized["validation"].select(
    range(min(N_SAMPLES, len(tokenized["validation"])))
)

In [None]:
# mlm=False selects causal (next-token) language modelling instead of masked LM;
# batches are padded up to a multiple of 64 tokens.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=64,
)

# Debugging aid: make one pass over the training subset and print the shape of each
# collated batch's input_ids.
train_dataloader = DataLoader(
    train_tokenized,
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)
for batch in train_dataloader:
    input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
    print(input_ids.shape)

In [None]:
# tokenizer.decode(input_ids[0])

In [None]:
# transformers 4.41 renamed `evaluation_strategy` to `eval_strategy` and later removed the
# old name. The transformers version is not pinned yet, so pass the evaluation schedule under
# whichever keyword the installed TrainingArguments accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="/tmp",
    # NOTE(review): prompt tuning typically needs a much larger learning rate than full
    # fine-tuning; 2e-5 over two short epochs may barely move the virtual tokens — confirm.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=2,
    push_to_hub=False,
    # Evaluate on the eval dataset at the end of every epoch.
    **{_eval_strategy_kwarg: "epoch"},
)


In [None]:
# transformers 4.46 deprecated Trainer's `tokenizer` argument in favour of `processing_class`.
# Pass the tokenizer under whichever name the installed version accepts.
_trainer_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    # TODO: add `compute_metrics` (e.g. answer accuracy) once an evaluation metric is defined.
    **{_trainer_tokenizer_kwarg: tokenizer},
)


In [None]:
# Launch training with the Trainer configured above; the value returned by `train()` is
# displayed as this cell's output.
trainer.train()
