In [63]:
# !pip install -q transformers peft accelerate torch tqdm
# !pip install "datasets<3.0.0"

In [64]:
import os
from transformers import (
  AutoModelForCausalLM,
  AutoTokenizer,
  default_data_collator,
  get_linear_schedule_with_warmup,
)
from peft import (
  get_peft_model,
  PromptTuningInit,
  PromptTuningConfig,
  TaskType,
)
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [65]:
# ----------------------------
# Config and hyperparameters
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name_or_path = "bigscience/bloomz-560m"
tokenizer_name_or_path = "bigscience/bloomz-560m"

dataset_name = "twitter_complaints"   # subset inside ought/raft
text_column = "Tweet text"
label_column = "text_label"

max_length = 64
lr = 3e-2
num_epochs = 5
batch_size = 8

In [66]:
# ----------------------------
# Load dataset
# ----------------------------
dataset = load_dataset("ought/raft", dataset_name)
print("Raw dataset example:")
print(dataset["train"][0])


Raw dataset example:
{'Tweet text': '@HMRCcustomers No this is my first job', 'ID': 0, 'Label': 2}


In [82]:
# Create human readable class labels
classes = [
    label.replace("_", " ")
    for label in dataset["train"].features["Label"].names
]


# Map numeric label -> text label
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["Label"]]},
    batched=True,
    num_proc=1,
)
print("\nWith text_label added:")
print(dataset["train"][0])


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/3399 [00:00<?, ? examples/s]


With text_label added:
{'Tweet text': '@HMRCcustomers No this is my first job', 'ID': 0, 'Label': 2, 'text_label': 'no complaint'}


In [83]:
# ----------------------------
# Load tokenizer
# ----------------------------
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

target_max_length = max(
    len(tokenizer(class_label)["input_ids"]) for class_label in classes
)
print(f"\nTarget label max token length: {target_max_length}")



Target label max token length: 3


In [84]:
# ----------------------------
# Helper: preprocess function
# ----------------------------
def preprocess_function(examples):
    batch_size = len(examples[text_column])
    # Build input template: "<column> : text Label : "
    inputs = [
        f"{text_column} : {x} Label : " for x in examples[text_column]
    ]
    targets = [str(x) for x in examples[label_column]]

    model_inputs = tokenizer(inputs)
    labels = tokenizer(targets)

    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]

        model_inputs["input_ids"][i] = sample_input_ids + label_input_ids

        # Mask out the input part in labels with -100 so loss is applied only on label tokens
        labels["input_ids"][i] = [-100] * len(sample_input_ids) + label_input_ids

        model_inputs["attention_mask"][i] = [1] * len(model_inputs["input_ids"][i])


    # 2. Left pad to max_length and convert to tensors
    for i in range(batch_size):
        sample_input_ids = model_inputs["input_ids"][i]
        label_input_ids = labels["input_ids"][i]

        pad_len = max_length - len(sample_input_ids)

        if pad_len < 0:
            # truncate if sequence is too long
            model_inputs["input_ids"][i] = torch.tensor(sample_input_ids[:max_length])
            model_inputs["attention_mask"][i] = torch.tensor(
                model_inputs["attention_mask"][i][:max_length]
            )
            labels["input_ids"][i] = torch.tensor(label_input_ids[:max_length])
        else:
            # left pad
            model_inputs["input_ids"][i] = torch.tensor(
                [tokenizer.pad_token_id] * pad_len + sample_input_ids
            )
            model_inputs["attention_mask"][i] = torch.tensor(
                [0] * pad_len + model_inputs["attention_mask"][i]
            )
            labels["input_ids"][i] = torch.tensor(
                [-100] * pad_len + label_input_ids
            )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [85]:
# ----------------------------
# Tokenize and preprocess
# ----------------------------
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/3399 [00:00<?, ? examples/s]

In [86]:
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["test"]


# eval_dataset = processed_datasets["train"]
# train_dataset = processed_datasets["test"]


train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=batch_size,
    pin_memory=True,
)
eval_dataloader = DataLoader(
    eval_dataset,
    collate_fn=default_data_collator,
    batch_size=batch_size,
    pin_memory=True,
)

In [87]:
# Prompt Tuning configuration
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=12,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    tokenizer_name_or_path=model_name_or_path,
)

checkpoint_name = (
    f"{dataset_name}_{model_name_or_path}"
    f"_{peft_config.peft_type}_{peft_config.task_type}_v1.pt".replace("/", "_")
)

In [88]:
# ----------------------------
# Load base model and wrap with PEFT prompt tuning
# ----------------------------
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model = model.to(device)


trainable params: 12,288 || all params: 559,226,880 || trainable%: 0.0022


In [89]:
# ----------------------------
# Optimizer and scheduler
# ----------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [90]:
# ----------------------------
# Training loop
# ----------------------------
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for step, batch in enumerate(tqdm(train_dataloader, desc=f"Training epoch {epoch}")):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)

    model.eval()
    eval_loss = 0
    eval_preds = []

    for step, batch in enumerate(tqdm(eval_dataloader, desc=f"Eval epoch {epoch}")):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()

        eval_preds.extend(
            tokenizer.batch_decode(
                torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                skip_special_tokens=True,
            )
        )

    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)

    print(
        f"epoch={epoch}: "
        f"train_ppl={train_ppl:.4f}, train_loss={train_epoch_loss:.4f}, "
        f"eval_ppl={eval_ppl:.4f}, eval_loss={eval_epoch_loss:.4f}"
    )


Training epoch 0: 100%|██████████| 7/7 [00:02<00:00,  2.64it/s]
Eval epoch 0: 100%|██████████| 425/425 [01:30<00:00,  4.68it/s]


epoch=0: train_ppl=1545284091904.0000, train_loss=28.0662, eval_ppl=7374.8608, eval_loss=8.9058


Training epoch 1: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s]
Eval epoch 1: 100%|██████████| 425/425 [01:30<00:00,  4.70it/s]


epoch=1: train_ppl=3117.4548, train_loss=8.0448, eval_ppl=6772.8535, eval_loss=8.8207


Training epoch 2: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s]
Eval epoch 2: 100%|██████████| 425/425 [01:30<00:00,  4.69it/s]


epoch=2: train_ppl=200.2994, train_loss=5.2998, eval_ppl=19976.0312, eval_loss=9.9023


Training epoch 3: 100%|██████████| 7/7 [00:02<00:00,  2.68it/s]
Eval epoch 3: 100%|██████████| 425/425 [01:30<00:00,  4.70it/s]


epoch=3: train_ppl=100.5085, train_loss=4.6102, eval_ppl=33714.8125, eval_loss=10.4257


Training epoch 4: 100%|██████████| 7/7 [00:02<00:00,  2.69it/s]
Eval epoch 4: 100%|██████████| 425/425 [01:30<00:00,  4.70it/s]

epoch=4: train_ppl=77.5200, train_loss=4.3505, eval_ppl=31537.6758, eval_loss=10.3589





In [92]:
# ----------------------------
# Save model and tokenizer
# ----------------------------
model_dir = "./models/PromptTunedPEFT"
os.makedirs(model_dir, exist_ok=True)

tokenizer.save_pretrained(model_dir)
model.save_pretrained(model_dir)

print(f"\nModel and tokenizer saved to {model_dir}")


Model and tokenizer saved to ./models/PromptTunedPEFT


In [93]:
# ----------------------------
# Inference example
# ----------------------------
test_text = (
    '@nationalgridus I have no water and the bill is current and paid. '
    'Can you do something about this?'
)

with torch.no_grad():
    prompt = f'{text_column} : {{"{test_text}"}} Label : '
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=5,
        eos_token_id=tokenizer.eos_token_id,
    )

    print("\nInference output:")
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))


Inference output:
['Tweet text : {"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?"} Label : no complaint']


In [81]:
import torch

# Clear PyTorch CUDA cache
torch.cuda.empty_cache()

# Optionally reset the allocator
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()


import gc
gc.collect()
torch.cuda.empty_cache()
