# Llama‑2‑7B‑Chat QLoRA SFT — Customer Service (English)

This notebook fine-tunes the meta-llama/Llama-2-7b-chat-hf model on customer-support datasets
so that it gives clear, step-by-step English answers to questions such as "How can I get a refund?".

**Datasets (Hugging Face):**
- `argilla/customer_assistant`
- `argilla/synthetic-sft-customer-support-single-turn`
- `bitext/Bitext-customer-support-llm-chatbot-training-dataset`

**Training recipe:** TRL's SFTTrainer with **PEFT QLoRA (4-bit)**.
Training uses a chat-style prompt template, and the loss is computed only on the assistant's response.
A quick inference example is included at the end.

In [8]:
%pip -q install --upgrade pip
%pip -q install \
  "transformers>=4.43,<4.47" \
  "accelerate>=0.33" \
  "datasets>=2.20" \
  "trl==0.9.6" \
  "peft>=0.12" \
  "einops" \
  "sentencepiece" \
  "hf_transfer"


Note: you may need to restart the kernel to use updated packages.


In [5]:
import sys, platform, subprocess

if platform.system() == "Linux":
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "bitsandbytes>=0.43"])
        print("bitsandbytes installed.")
    except subprocess.CalledProcessError as e:
        print("bitsandbytes install failed. Check CUDA/GPU runtime.")
else:
    print("Non-Linux detected — skipping bitsandbytes. (Use full precision or provide a compatible wheel.)")

# (Optional) avoid Accelerate interactive prompt later
try:
    from accelerate.utils import write_basic_config
    write_basic_config()
except Exception:
    pass
print("Done.")


✅ bitsandbytes installed.
Done.


In [1]:
import os, random, math, json, re
from dataclasses import dataclass
from typing import Dict, List, Any
from datasets import load_dataset, DatasetDict, concatenate_datasets
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
                          TrainingArguments, AutoConfig)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model
import torch

print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())

2.8.0+cu128
CUDA available: True


## Log in to Hugging Face (required for Llama‑2)

In [2]:
# --- Hugging Face login helper ---
try:
    from huggingface_hub import login, whoami
    import os
    if os.environ.get("HF_TOKEN"):
        login(token=os.environ["HF_TOKEN"])
        try:
            print("HF whoami:", whoami())
        except Exception:
            pass
    else:
        print(" Set HF_TOKEN env var or run:")
        print("   from huggingface_hub import login; login(token='hf_...')")
except Exception as e:
    print("Hugging Face login not available:", e)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


HF whoami: {'type': 'user', 'id': '6900911d2cafc5f572673a1c', 'name': 'survd0404', 'fullname': 'Lee', 'isPro': False, 'avatarUrl': '/avatars/6738b5fc2f71a66ab3ca028ab9a5da26.svg', 'orgs': [], 'auth': {'type': 'access_token', 'accessToken': {'displayName': 'HF_TOKEN', 'role': 'fineGrained', 'createdAt': '2025-11-05T14:30:18.423Z', 'fineGrained': {'canReadGatedRepos': True, 'global': ['discussion.write', 'post.write'], 'scoped': [{'entity': {'_id': '6900911d2cafc5f572673a1c', 'type': 'user', 'name': 'survd0404'}, 'permissions': ['repo.content.read', 'repo.write', 'inference.serverless.write', 'inference.endpoints.infer.write', 'inference.endpoints.write', 'user.webhooks.read', 'user.webhooks.write', 'collection.read', 'collection.write', 'discussion.write']}]}}}}


## Config

In [3]:
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

DATASETS = [
    ("argilla/customer_assistant", {}),
    ("argilla/synthetic-sft-customer-support-single-turn", {}),
    ("bitext/Bitext-customer-support-llm-chatbot-training-dataset", {}),
]

OUTPUT_DIR = "out_llama2_cs_qlora"
SEED = 42

# QLoRA / training settings tuned for a 24GB GPU; reduce if you get OOM.
MAX_SEQ_LEN = 1024
PER_DEVICE_TRAIN_BS = 1
GRAD_ACCUM = 8
EPOCHS = 2
LR = 2e-4
WARMUP_RATIO = 0.03

# LoRA
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Response template for loss masking
RESPONSE_TEMPLATE = "\n### Assistant:\n"
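
As a rough guide to what these settings imply, the sketch below works out the effective batch size and the number of optimizer steps. It assumes a single GPU and the 27,168 merged examples reported in the dataset cell's output; the rounding is illustrative and may differ slightly from what the Trainer computes internally.

```python
import math

NUM_EXAMPLES = 27_168        # assumption: merged dataset size, single GPU
PER_DEVICE_TRAIN_BS = 1
GRAD_ACCUM = 8
EPOCHS = 2
WARMUP_RATIO = 0.03

# One optimizer step consumes BS * GRAD_ACCUM examples.
effective_batch = PER_DEVICE_TRAIN_BS * GRAD_ACCUM
steps_per_epoch = math.ceil(NUM_EXAMPLES / effective_batch)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)

print("effective batch size:", effective_batch)
print("optimizer steps per epoch:", steps_per_epoch)
print("total optimizer steps:", total_steps)
print("approx. warmup steps:", warmup_steps)
```

If an out-of-memory error forces a shorter `MAX_SEQ_LEN`, the step counts stay the same; only raising `GRAD_ACCUM` or the batch size changes them.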

## Load tokenizer & quantized base model (4‑bit)

In [4]:
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, token=os.getenv("HF_TOKEN"))
# Ensure pad token exists
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_cfg,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
    device_map="auto",
    token=os.getenv("HF_TOKEN"),
)

# Attach LoRA
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

trainable params: 39,976,960 || all params: 6,778,392,576 || trainable%: 0.5898
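
The trainable-parameter count printed above can be reproduced by hand. The sketch below assumes the standard Llama-2-7B shapes (hidden size 4096, MLP intermediate size 11008, 32 decoder layers) and uses the fact that a rank-`r` LoRA adapter on a linear layer adds `r * (in_features + out_features)` parameters.

```python
HIDDEN = 4096          # assumption: Llama-2-7B hidden size
INTERMEDIATE = 11008   # assumption: Llama-2-7B MLP intermediate size
NUM_LAYERS = 32        # assumption: number of decoder layers
R = 16                 # LoRA rank from lora_cfg

# (in_features, out_features) of each targeted projection in one decoder layer
TARGETS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(in_features: int, out_features: int, r: int) -> int:
    # A is (r x in_features), B is (out_features x r)
    return r * (in_features + out_features)

per_layer = sum(lora_params(i, o, R) for i, o in TARGETS.values())
trainable = per_layer * NUM_LAYERS

ALL_PARAMS = 6_778_392_576  # "all params" value from the printed output
pct = 100 * trainable / ALL_PARAMS
print(f"trainable params: {trainable:,} || trainable%: {pct:.4f}")
```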


## Dataset normalizer
We’ll coerce each dataset into a common schema: `{user, assistant}` with optional `system`.

In [5]:
def extract_user_assistant(record: Dict[str, Any]) -> Dict[str, str]:
    """Heuristic field mapping across the three datasets.
    Returns keys: user, assistant, and optional system.
    """
    keys = list(record.keys())  # list, not set: keeps column order so the fallback below is deterministic

    # Common patterns
    user_candidates = [
        "user", "question", "prompt", "instruction", "input", "request", "query", "messages_user", "customer", "customer_message"
    ]
    assistant_candidates = [
        "assistant", "answer", "response", "output", "messages_assistant", "agent", "agent_response"
    ]
    system_candidates = ["system", "context", "role", "scenario"]

    def pick(cands):
        for c in cands:
            if c in record and isinstance(record[c], str) and len(record[c].strip()) > 0:
                return record[c]
        # Look for nested message lists (e.g., [{"role":"user","content":...}, ...])
        for k in keys:
            v = record[k]
            if isinstance(v, list) and len(v)>0 and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]:
                # try to assemble conversation
                user_msg = None
                assistant_msg = None
                system_msg = None
                for m in v:
                    role = m.get("role","")
                    content = m.get("content","")
                    if role == "system" and not system_msg:
                        system_msg = content
                    if role in ("user","customer") and not user_msg:
                        user_msg = content
                    if role in ("assistant","agent") and not assistant_msg:
                        assistant_msg = content
                if cands is user_candidates and user_msg:
                    return user_msg
                if cands is assistant_candidates and assistant_msg:
                    return assistant_msg
                if cands is system_candidates and system_msg:
                    return system_msg
        return None

    user = pick(user_candidates)
    assistant = pick(assistant_candidates)
    system = pick(system_candidates) or "You are a helpful, step-by-step customer support assistant."

    if not user or not assistant:
        # Fallback: try to synthesize from any two text-like fields
        textish = [k for k in keys if isinstance(record[k], str) and len(record[k].strip())>0]
        if len(textish) >= 2 and not user:
            user = record[textish[0]]
        if len(textish) >= 2 and not assistant:
            assistant = record[textish[1]]

    if not user or not assistant:
        raise ValueError("Could not map record to user/assistant: keys=%s" % list(keys))

    return {"user": user.strip(), "assistant": assistant.strip(), "system": system.strip()}

def make_text(example: Dict[str, str]) -> str:
    # Simple, stable chat template compatible with response-only loss masking.
    # We place a clear response boundary: `\n### Assistant:\n`
    sys_ = example.get("system","You are a helpful, step-by-step customer support assistant.")
    usr = example["user"]
    asst = example["assistant"]
    return f"""### System:
{sys_}
### User:
{usr}
### Assistant:
{asst}"""

def preprocess_dataset(hf_id: str, hf_config: Dict[str, Any]):
    ds = load_dataset(hf_id, **hf_config)
    split_names = [k for k in ds.keys()] if isinstance(ds, DatasetDict) else ["train"]
    out_splits = []
    for split in split_names:
        d = ds[split] if isinstance(ds, DatasetDict) else ds
        # Filter/map to unified schema
        def _map_fn(rec):
            try:
                ex = extract_user_assistant(rec)
                return {"text": make_text(ex)}
            except Exception:
                return {"text": None}
        d2 = d.map(_map_fn, remove_columns=[c for c in d.column_names if c!="text"])
        d2 = d2.filter(lambda r: r["text"] is not None and len(r["text"])>20)
        out_splits.append(d2)
    # Concatenate all available splits
    merged = concatenate_datasets(out_splits) if len(out_splits)>1 else out_splits[0]
    return merged

all_ds = []
for name, cfg in DATASETS:
    print("Loading:", name, cfg)
    try:
        all_ds.append(preprocess_dataset(name, cfg))
    except Exception as e:
        print("!! Skipped", name, "due to", e)

assert len(all_ds) > 0, "No datasets loaded—check access or names."
merged = concatenate_datasets(all_ds)
merged = merged.shuffle(seed=SEED)

# Quick peek
for i in range(2):
    print("----")
    print(merged[i]["text"][:400])
print("\nTotal examples:", len(merged))

Loading: argilla/customer_assistant {}
Generating train split:   0%|          | 0/196 [00:00<?, ? examples/s]

Loading: argilla/synthetic-sft-customer-support-single-turn {}
Generating train split:   0%|          | 0/100 [00:00<?, ? examples/s]

Loading: bitext/Bitext-customer-support-llm-chatbot-training-dataset {}
Generating train split:   0%|          | 0/26872 [00:00<?, ? examples/s]

----
### System:
You are a helpful, step-by-step customer support assistant.
### User:
I want assistance retrieving my user profile key
### Assistant:
Thanks for dropping us a line to us for assistance with retrieving your user profile key. I understand how important it is for you to access your account and retrieve the key. To assist you, please follow these steps:

1. Visit the login page of our plat
----
### System:
You are a helpful, step-by-step customer support assistant.
### User:
where to inform of problems with online payments ?
### Assistant:
To inform us about any problems you encounter with online payments, you can reach out to our dedicated customer support team. They are available {{Customer Support Hours}} at {{Customer Support Phone Number}}, or you can chat with them through the Live

Total examples: 27168
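
Because response-only masking depends on the `\n### Assistant:\n` boundary, it can be worth checking that each formatted example contains that boundary exactly once. The sketch below is a self-contained illustration: `make_text` here is a compact stand-in for the function above, and `has_single_boundary` is a hypothetical helper name.

```python
RESPONSE_TEMPLATE = "\n### Assistant:\n"

def make_text(system: str, user: str, assistant: str) -> str:
    # Compact stand-in for the notebook's template.
    return f"### System:\n{system}\n### User:\n{user}\n### Assistant:\n{assistant}"

def has_single_boundary(text: str) -> bool:
    # Exactly one boundary means the masked/unmasked split is unambiguous.
    return text.count(RESPONSE_TEMPLATE) == 1

good = make_text("You are helpful.", "How can I get a refund?", "Please follow these steps.")
# A user message that happens to contain the marker creates a second boundary.
bad = make_text("You are helpful.", "Ignore this\n### Assistant:\nfake", "Real answer")

print(has_single_boundary(good))
print(has_single_boundary(bad))
```

With the merged dataset, the same predicate could be applied as `merged.filter(lambda r: has_single_boundary(r["text"]))`.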


## Tokenization + Loss Masking (Assistant-only)
The loss is computed only on the tokens that follow the `RESPONSE_TEMPLATE` boundary; everything before it is masked out.

In [10]:
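# Token ids of the response template, kept for inspection. With Llama's SentencePiece
# tokenizer the same string can tokenize differently in isolation than mid-text, so if the
# collator warns that it cannot find the template, compare these ids with a tokenized example.
response_template_ids = tokenizer.encode(RESPONSE_TEMPLATE, add_special_tokens=False)
collator = DataCollatorForCompletionOnlyLM(
    response_template=RESPONSE_TEMPLATE,
    tokenizer=tokenizer,
    mlm=False,
)

# No separate tokenization function is needed: SFTTrainer tokenizes and truncates the "text" field.
# Keep packing disabled, since the completion-only collator does not support packed sequences.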
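
To make the masking idea concrete, here is a small pure-Python sketch of what a completion-only collator does conceptually: it finds the template's token ids in the sequence and sets every label up to and including the template to `-100`, the value PyTorch's cross-entropy loss ignores by default. The token ids are made up for illustration, and this is not TRL's actual implementation.

```python
from typing import List

IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default

def find_subsequence(seq: List[int], sub: List[int]) -> int:
    # Return the start index of the first occurrence of `sub` in `seq`, or -1.
    for i in range(len(seq) - len(sub) + 1):
        if seq[i:i + len(sub)] == sub:
            return i
    return -1

def mask_prompt(input_ids: List[int], template_ids: List[int]) -> List[int]:
    start = find_subsequence(input_ids, template_ids)
    if start == -1:
        # Template not found: nothing is trained on for this example.
        return [IGNORE_INDEX] * len(input_ids)
    boundary = start + len(template_ids)
    return [IGNORE_INDEX] * boundary + input_ids[boundary:]

# Hypothetical ids: prompt tokens, then the template (90, 91, 92), then the answer.
input_ids = [1, 11, 12, 13, 90, 91, 92, 21, 22, 2]
template_ids = [90, 91, 92]
labels = mask_prompt(input_ids, template_ids)
print(labels)
```

Only the last three positions keep their labels, so only the assistant's tokens contribute to the loss.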

## Train with TRL SFTTrainer (QLoRA)
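The save cell that follows expects a `trainer` object. A minimal sketch of how one could be constructed with TRL 0.9.6 from the objects defined earlier (`model`, `tokenizer`, `merged`, `collator`) is given here. The helper name `build_trainer` and the `logging_steps`, `save_strategy`, and `report_to` values are assumptions for illustration, not settings confirmed by this notebook.

```python
def build_trainer(model, tokenizer, train_dataset, collator,
                  output_dir="out_llama2_cs_qlora"):
    """Sketch: wire the notebook's config values into TRL's SFTTrainer."""
    # Imported lazily so the function can be defined without TRL installed.
    import torch
    from trl import SFTConfig, SFTTrainer

    cuda = torch.cuda.is_available()
    use_bf16 = cuda and torch.cuda.is_bf16_supported()

    args = SFTConfig(
        output_dir=output_dir,
        dataset_text_field="text",      # column produced by the preprocessing step
        max_seq_length=1024,            # MAX_SEQ_LEN
        packing=False,                  # required with the completion-only collator
        per_device_train_batch_size=1,  # PER_DEVICE_TRAIN_BS
        gradient_accumulation_steps=8,  # GRAD_ACCUM
        num_train_epochs=2,             # EPOCHS
        learning_rate=2e-4,             # LR
        warmup_ratio=0.03,              # WARMUP_RATIO
        logging_steps=20,               # assumption
        save_strategy="epoch",          # assumption
        bf16=use_bf16,
        fp16=cuda and not use_bf16,
        seed=42,                        # SEED
        report_to="none",               # assumption: no experiment tracker
    )
    # The model already carries the LoRA adapter, so no peft_config is passed.
    return SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        args=args,
        data_collator=collator,
    )
```

Under these assumptions, usage would look like `trainer = build_trainer(model, tokenizer, merged, collator)` followed by `trainer.train()`.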

In [16]:
import os, glob

ADAPTER_DIR = os.path.join(OUTPUT_DIR, "lora_adapter")
os.makedirs(ADAPTER_DIR, exist_ok=True)

# Save LoRA adapter (creates adapter_config.json + adapter_model.safetensors)
trainer.model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Adapter saved to:", ADAPTER_DIR)

# === merge LoRA into base (optional; needs VRAM) ===
try:
    from peft import PeftModel

    def find_adapter_dir(base_dir: str) -> str:
        cand = os.path.join(base_dir, "lora_adapter")
        if os.path.exists(os.path.join(cand, "adapter_config.json")):
            return cand
        ckpts = sorted(
            glob.glob(os.path.join(base_dir, "checkpoint-*")),
            key=lambda p: int(p.split("-")[-1]),
            reverse=True
        )
        for c in ckpts:
            for sub in ("", "lora_adapter"):
                d = os.path.join(c, sub)
                if os.path.exists(os.path.join(d, "adapter_config.json")):
                    return d
        raise FileNotFoundError(f"No adapter_config.json found under {base_dir}")

    adapter_dir = find_adapter_dir(OUTPUT_DIR)

    # Load the unquantized base model (bf16/fp16) for the merge
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=None,
        torch_dtype=torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16,
        device_map="auto",
        token=os.getenv("HF_TOKEN"),
    )

    peft_model = PeftModel.from_pretrained(base, adapter_dir)
    merged_model = peft_model.merge_and_unload()  # -> plain HF model with LoRA merged

    merged_dir = os.path.join(OUTPUT_DIR, "merged")
    os.makedirs(merged_dir, exist_ok=True)
    merged_model.save_pretrained(merged_dir, safe_serialization=True)
    tokenizer.save_pretrained(merged_dir)
    print("Merged model saved to:", merged_dir)

except Exception as e:
    print("Merge skipped due to:", e)


Adapter saved to: out_llama2_cs_qlora/lora_adapter


Merged model saved to: out_llama2_cs_qlora/merged


## (Optional) Merge LoRA into Base Weights and Save
Merging loads the unquantized base model, so it needs enough memory for the half-precision weights; you can skip this step and load the adapter at inference time instead.
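
For a sense of scale, the sketch below estimates the memory taken by the base weights alone at different precisions, using the parameter counts printed earlier. The 4-bit figure is an approximation that ignores quantization constants and layers that stay unquantized, and none of the figures include activations or framework overhead.

```python
ALL_PARAMS = 6_778_392_576   # "all params" from print_trainable_parameters()
LORA_PARAMS = 39_976_960     # "trainable params" from the same output
BASE_PARAMS = ALL_PARAMS - LORA_PARAMS

BYTES_PER_PARAM = {
    "float32": 4.0,
    "float16/bfloat16": 2.0,
    "nf4 (approx.)": 0.5,
}

def weights_gb(dtype: str) -> float:
    # Decimal gigabytes (1e9 bytes) for the weights only.
    return BASE_PARAMS * BYTES_PER_PARAM[dtype] / 1e9

for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>18}: {weights_gb(dtype):6.2f} GB")
```

The half-precision row is the relevant one for the merge, since the base model is loaded without quantization there.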

In [17]:
try:
    from peft import PeftModel
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=None,  # load unquantized (fp16) for the merge
        torch_dtype=torch.float16,
        device_map="auto",
        token=os.getenv("HF_TOKEN"),
    )
    peft_model = PeftModel.from_pretrained(
        base,
        os.path.join(OUTPUT_DIR, "lora_adapter"),
    )
    merged_model = peft_model.merge_and_unload()
    merged_dir = os.path.join(OUTPUT_DIR, "merged")
    os.makedirs(merged_dir, exist_ok=True)
    merged_model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print("Merged model saved to:", merged_dir)
except Exception as e:
    print("Merge skipped due to:", e)

Merged model saved to: out_llama2_cs_qlora/merged


## Inference demo: “How can I get a refund?”

In [18]:
from transformers import pipeline

# Load the 4-bit base model plus the LoRA adapter for inference (saves VRAM)
inf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_cfg,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
    device_map="auto",
    token=os.getenv("HF_TOKEN"),
)
from peft import PeftModel
inf_model = PeftModel.from_pretrained(inf_model, os.path.join(OUTPUT_DIR, "lora_adapter"))
inf_model.eval()

pipe = pipeline("text-generation", model=inf_model, tokenizer=tokenizer, device_map="auto")

def build_prompt(user_question: str, system: str = "You are a helpful, step-by-step customer support assistant."):
    return f"""### System:
{system}
### User:
{user_question}
### Assistant:
"""

prompt = build_prompt("How can I get a refund?")
out = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9, eos_token_id=tokenizer.eos_token_id)
print(out[0]["generated_text"][len(prompt):].strip())

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausa

Great, thank you for reaching out to us! To initiate a refund, please follow these steps:

1. Check your order history on our website to ensure that the item you want a refund for is eligible for a return.
2. Contact our customer service team via email or phone to request a return. Please include your order number and the reason for the return.
3. Once we receive your return request, we will provide you with a return shipping label and instructions on how to return the item.
4. Once we receive the returned item, we will process your refund within 3-5 business days.

Please let me know if you have any other questions or concerns!
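
The demo slices the prompt off the generated text by length. With sampling, a model may also keep going and start a new `### User:` turn, so a slightly more defensive post-processing step can help. The sketch below is one possible approach; `extract_reply` and `STOP_MARKERS` are hypothetical names, not part of the notebook.

```python
STOP_MARKERS = ("### User:", "### System:", "### Assistant:")

def extract_reply(generated_text: str, prompt: str) -> str:
    # Drop the echoed prompt if the pipeline returned prompt + completion.
    if generated_text.startswith(prompt):
        reply = generated_text[len(prompt):]
    else:
        reply = generated_text
    # Cut at the earliest marker that would begin another turn.
    cut = len(reply)
    for marker in STOP_MARKERS:
        idx = reply.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return reply[:cut].strip()

prompt = "### System:\nBe helpful.\n### User:\nHow can I get a refund?\n### Assistant:\n"
raw = prompt + "Please contact support with your order number.\n### User:\nThanks!"
print(extract_reply(raw, prompt))
```

In the demo cell this would replace the final slice, for example `print(extract_reply(out[0]["generated_text"], prompt))`.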
