In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"         # single GPU, avoids DataParallel
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import polars as pl
from datasets import Dataset
import torch
import gc
import random
import emoji

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    set_seed,
    EarlyStoppingCallback
)

from peft import (
    PromptTuningConfig,
    get_peft_model,
    TaskType,
    PeftModel
)

In [None]:
from huggingface_hub import login

# Insert your token here
login(token="")

In [3]:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [4]:
df = pl.read_csv('hf://datasets/bitext/Bitext-customer-support-llm-chatbot-training-dataset/Bitext_Sample_Customer_Support_Training_Dataset_27K_responses-v11.csv')

# Data cleaning
# fill in null
df = df.with_columns([
    pl.col("instruction")
      .cast(str)
      .str.to_lowercase()
      .fill_null(""),
    
    pl.col("response")
      .cast(str)
      .str.to_lowercase()
      .fill_null(""),
    
    pl.col("intent")
      .cast(str)
      .fill_null("unknown")
])

# Remove emoji
def remove_emojis(text: str) -> str:
    return emoji.replace_emoji(text, replace="")  # Remove all emojis safely

# Apply to instruction and response
df = df.with_columns([
    pl.col("instruction").map_elements(remove_emojis).alias("instruction")
])

df = df.with_columns([
    pl.col("response").map_elements(remove_emojis).alias("response")
])

#Exclude noisy flags
# Filter out rows where 'flags' contains Z, Q, or W ===
flag = ["flags"]
df_z = df.filter(
    ~pl.col(flag).cast(str).str.contains("Z")
)

df_zw = df.filter(
    ~pl.col(flag).cast(str).str.contains("Z") &
    ~pl.col(flag).cast(str).str.contains("W")
)

df_clean = df.filter(
    ~pl.col(flag).cast(str).str.contains("Z") &
    ~pl.col(flag).cast(str).str.contains("W") &
    ~pl.col(flag).cast(str).str.contains("Q")
)

print(df_z.height)
print(df_zw.height)
print(df_clean.height)

category_counts = (
    df_clean
    .group_by("category")
    .agg(pl.count().alias("counts"))
    .sort("counts", descending=True)
)

print(category_counts)

# Filter only selected categories
selected_categories = ["ORDER", "REFUND", "SHIPPING", "DELIVERY"]
df_selected = df_clean.filter(
    pl.col("category").is_in(selected_categories)
)


# Split dataset by category
# === Configuration ===
LABEL_COL = "category"  # 🔁 Replace with "intent" or any stratification column
SPLIT_RATIO_TRAIN = 0.7
SPLIT_RATIO_VAL = 0.15
SEED = 123
df_final = df_selected.clone()

# === Stratified split logic ===
random.seed(SEED)
train_parts = []
test_parts = []
val_parts = []

for label in df_final[LABEL_COL].unique().to_list():
    group_df = df_final.filter(pl.col(LABEL_COL) == label)
    group_df = group_df.sample(n=len(group_df), shuffle=True, seed=SEED)

    n = len(group_df)
    train_idx = int(n * SPLIT_RATIO_TRAIN)
    val_idx = int(n * (SPLIT_RATIO_TRAIN + SPLIT_RATIO_VAL))

    train_parts.append(group_df[:train_idx])
    val_parts.append(group_df[train_idx:val_idx])
    test_parts.append(group_df[val_idx:])

# === Combine all groups
train_df = pl.concat(train_parts).sort(["category", "instruction"])
val_df = pl.concat(val_parts).sort(["category", "instruction"])
test_df = pl.concat(test_parts).sort(["category", "instruction"])

print("✅ Split sizes:")
print(f"Train: {len(train_df)}")
print(f"Val:   {len(val_df)}")
print(f"Test:  {len(test_df)}")

print("\n📊 Category distribution in test set:")
print(test_df.select([pl.col(LABEL_COL)]).to_series().value_counts())

  df = df.with_columns([
  df = df.with_columns([


21586
20517
14454
shape: (11, 2)
┌──────────────┬────────┐
│ category     ┆ counts │
│ ---          ┆ ---    │
│ str          ┆ u32    │
╞══════════════╪════════╡
│ ACCOUNT      ┆ 3251   │
│ ORDER        ┆ 2152   │
│ REFUND       ┆ 1527   │
│ SHIPPING     ┆ 1156   │
│ DELIVERY     ┆ 1102   │
│ …            ┆ …      │
│ INVOICE      ┆ 1076   │
│ PAYMENT      ┆ 1028   │
│ FEEDBACK     ┆ 1004   │
│ CANCEL       ┆ 539    │
│ SUBSCRIPTION ┆ 537    │
└──────────────┴────────┘
✅ Split sizes:
Train: 4154
Val:   890
Test:  893

📊 Category distribution in test set:
shape: (4, 2)
┌──────────┬───────┐
│ category ┆ count │
│ ---      ┆ ---   │
│ str      ┆ u32   │
╞══════════╪═══════╡
│ DELIVERY ┆ 166   │
│ ORDER    ┆ 323   │
│ REFUND   ┆ 230   │
│ SHIPPING ┆ 174   │
└──────────┴───────┘


(Deprecated in version 0.20.5)
  .agg(pl.count().alias("counts"))


In [5]:
# ---------- sanity & speed ----------
assert torch.cuda.is_available(), "❌ CUDA is required for Prompt Tuning"
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(123)

# ---------- your models ----------
MODELS = {
    # "llama":    "meta-llama/Llama-3.2-1B",
    # "qwen":     "Qwen/Qwen3-0.6B-Base",
    "olmo":     "allenai/OLMo-2-0425-1B"
}
OUT_ROOT = "prompt-tune-outputs"
os.makedirs(OUT_ROOT, exist_ok=True)

# ---------- dataset builder ----------
def build_train_dataset(df: pl.DataFrame) -> Dataset:
    df_text = (
        df.select(["instruction", "response"])
          .drop_nulls(["instruction", "response"])
          .with_columns(
              pl.col("instruction").cast(pl.Utf8),
              pl.col("response").cast(pl.Utf8),
          )
          .with_columns(
              pl.format(
                  "You are a helpful retail assistant. Answer the following customer query briefly and accurately.\n\nCustomer: {}\nAnswer: {}",
                  pl.col("instruction"),
                  pl.col("response"),
                #   pl.lit(tokenizer.eos_token)   # <-- add eos
              ).alias("text")
          )
          .select(["text"])
    )
    return Dataset.from_dict({"text": df_text["text"].to_list()})


train_dataset = build_train_dataset(train_df)
val_dataset   = build_train_dataset(val_df) if ('val_df' in globals() and val_df is not None) else None

# ---------- trainer settings ----------
def make_training_args(out_dir: str, supports_bf16: bool) -> TrainingArguments:
    return TrainingArguments(
        output_dir=out_dir,
        num_train_epochs=50,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=32,
        learning_rate=5e-4,                  # often higher than LoRA
        logging_strategy="epoch",
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        prediction_loss_only=True,
        bf16=supports_bf16,
        fp16=not supports_bf16,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        report_to=[],
    )

def train_one_model(model_name: str, key: str):
    print(f"\n🚀 Training Prompt Tuning for [{key}] {model_name}")

    # bf16 support check
    dev = torch.cuda.current_device()
    supports_bf16 = torch.cuda.get_device_capability(dev)[0] >= 8

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"
    
    # base model (frozen)
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="sdpa",
        device_map="auto",
        torch_dtype=torch.bfloat16 if supports_bf16 else torch.float16,
        trust_remote_code=True,
    )
    base.config.use_cache = False  # needed with checkpointing

    # prepare prompt tuning config
    prompt_cfg = PromptTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=30,       # size of learned soft prompt
        tokenizer_name_or_path=model_name,
    )
    peft_model = get_peft_model(base, prompt_cfg)
    peft_model.print_trainable_parameters()

    # trainer setup
    out_dir = os.path.join(OUT_ROOT, f"{model_name.split('/')[-1]}-faq")
    args = make_training_args(out_dir, supports_bf16)

    def _tok(batch):
        return tokenizer(batch["text"], truncation=True, max_length=1024)

    train_tok = train_dataset.map(_tok, batched=True, remove_columns=train_dataset.column_names)
    val_tok   = val_dataset.map(_tok, batched=True, remove_columns=val_dataset.column_names) if val_dataset else None

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = Trainer(
        model=peft_model,
        args=args,
        train_dataset=train_tok,
        eval_dataset=val_tok,
        data_collator=collator,
        callbacks=[EarlyStoppingCallback(
            early_stopping_patience=3,
            early_stopping_threshold=0.00001
        )],
    )

    # continue from previous training
    last_ckpt = None
    if os.path.isdir(out_dir):
        ckpts = [os.path.join(out_dir, d) for d in os.listdir(out_dir) if d.startswith("checkpoint")]
        if ckpts:
            last_ckpt = max(ckpts, key=os.path.getmtime)
            print(f"🔁 Resuming from: {last_ckpt}")

    # train & save
    # continue training
    trainer.train(resume_from_checkpoint=last_ckpt) 
    # trainer.train()
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)
    print(f"✅ Saved Prompt Tuning adapter to: {out_dir}")

    # save logs
    df_logs = pl.DataFrame(trainer.state.log_history)
    df_logs.write_csv(os.path.join(out_dir, "train_eval_log.csv"))

    # cleanup
    del trainer, peft_model, base, tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# ---------- loop over models ----------
for key, model_name in MODELS.items():
    try:
        train_one_model(model_name, key)
    except Exception as e:
        print(f"❌ Skipping [{key}] due to error: {e}")
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        continue

print("\n🎉 All selected models prompt-tuned.")



🚀 Training Prompt Tuning for [olmo] allenai/OLMo-2-0425-1B


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 61,440 || all params: 1,484,978,176 || trainable%: 0.0041


Map:   0%|          | 0/4154 [00:00<?, ? examples/s]

Map:   0%|          | 0/890 [00:00<?, ? examples/s]

🔁 Resuming from: prompt-tune-outputs/OLMo-2-0425-1B-faq/checkpoint-1300


Epoch,Training Loss,Validation Loss
21,1.3818,1.338635
22,1.3747,1.332626
23,1.3701,1.326791
24,1.3643,1.322115
25,1.3591,1.317907
26,1.3547,1.313043
27,1.35,1.308115
28,1.3468,1.30408
29,1.343,1.300868
30,1.3399,1.296871


✅ Saved Prompt Tuning adapter to: prompt-tune-outputs/OLMo-2-0425-1B-faq

🎉 All selected models prompt-tuned.
