## Prompt Tuning

In [None]:
from collections import Counter

In [None]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm
import wandb

In [None]:
# Start a Weights & Biases run so this experiment is tracked under its own name.
wandb.init(
    name="prompt_tuning",
    project="prompt_learning_methods",
)

In [None]:
# Experiment configuration shared by the cells below.
seed = 42
# Fall back to CPU when no CUDA device is present, instead of failing later
# when the model/tensors are moved to a hard-coded "cuda" device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name_or_path = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer_name_or_path = "meta-llama/Llama-3.2-3B-Instruct"
dataset_name = "twitter_complaints"  # subset of the "ought/raft" dataset
text_column = "Tweet text"  # column holding the input text
label_column = "text_label"  # string-label column added by the map() cell
max_length = 64
lr = 1e-4
num_epochs = 10
batch_size = 8
# Seed the RNGs (via transformers.set_seed) for reproducible runs.
set_seed(seed)

### Dataset Preparation

In [None]:
# Load the RAFT benchmark subset selected by `dataset_name` (a DatasetDict).
dataset = load_dataset("ought/raft", dataset_name)

In [None]:
# Human-readable class names: the "Label" feature's names with underscores
# turned into spaces. Position i corresponds to integer label id i.
classes = [
    label_name.replace("_", " ")
    for label_name in dataset["train"].features["Label"].names
]
print(classes)

In [None]:
# Add a string-label column (named by `label_column`) to every split by mapping
# each integer "Label" id to its class name in `classes`. With batched=True the
# function receives a batch dict whose "Label" entry is a list of ids.
dataset = dataset.map(
    lambda batch: {label_column: [classes[label_id] for label_id in batch["Label"]]},
    batched=True,
    num_proc=1,
)
print(dataset)

In [None]:
# Display the first training example, including the added string-label column.
dataset["train"][0]

In [None]:
# Class balance of the train split: how often each integer "Label" id occurs
# (ids index into `classes`).
Counter(dataset["train"]["Label"])

### Preprocess the dataset

In [None]:
# Load the tokenizer from its own configured path (kept separate from the model
# path so the two can differ). If the checkpoint requires Hub authentication,
# pass `token=...` here.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

In [None]:
# The tokenizer may ship without a pad token; reuse EOS for padding in that case.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Token length of the longest class name, presumably used later to size the
# target (label) sequences.
# NOTE(review): tokenizer(...) applies its default special-token handling, so
# this count may include a BOS/EOS token — confirm it matches how targets are
# tokenized downstream.
target_max_length = max(
    len(tokenizer(class_label)["input_ids"]) for class_label in classes
)
print(f"{target_max_length=}")

In [None]:
# Inspect the raw input texts of the train split. This uses the public column
# accessor and the shared `text_column` name rather than reaching into the
# underlying Arrow table via `.data` with a hard-coded column string.
dataset["train"][text_column]