
# SFT with Template Masking (LoRA): Argilla/Bitext + Dolly (7B)

This notebook fine-tunes a 7B causal language model (default: the base checkpoint `meta-llama/Llama-2-7b-hf`) using **template masking** (no custom collator), so that loss is computed only on the **assistant** turns.  
Key points:

- **No DataCollator** required. The tokenizer's **chat template** wraps assistant turns in `{% generation %}...{% endgeneration %}`, and TRL uses these markers to build the loss mask when `assistant_only_loss=True` is set in `SFTConfig`.
- `packing=True` works (higher throughput) since masking is handled by the chat template.
- `tokenizer.truncation_side="left"` to preserve the assistant turn at the end of long sequences.
- Supports Argilla/Bitext customer-support datasets + Dolly-15k subset.
- Uses **LoRA** (PEFT) to reduce VRAM usage.
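
To make the masking idea concrete, here is a minimal sketch of how an assistant-token mask could be turned into training labels. It assumes the usual PyTorch convention that a label of `-100` is ignored by the cross-entropy loss. The helper `build_labels` and the token ids are hypothetical and only illustrate the principle; this is not TRL's actual code.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def build_labels(input_ids, assistant_mask):
    """Keep the token id where the mask is 1; ignore every other position."""
    if len(input_ids) != len(assistant_mask):
        raise ValueError("input_ids and assistant_mask must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, assistant_mask)]


# Hypothetical token ids for a system turn, a user turn, and an assistant turn
input_ids      = [11, 12, 13, 21, 22, 31, 32, 33]
assistant_mask = [0,  0,  0,  0,  0,  1,  1,  1]

labels = build_labels(input_ids, assistant_mask)
print(labels)  # [-100, -100, -100, -100, -100, 31, 32, 33]
```

Only the last three positions contribute to the loss, which is the effect the `{% generation %}` markers are meant to produce for assistant turns.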


In [2]:

# %%capture
# (Optional) Install versions known to work well together
!pip install -U transformers accelerate datasets peft bitsandbytes
!pip install -U "trl>=0.25.0"  # quoted so the shell does not treat `>` as a redirect
# (Optional) flash-attn (environment dependent)
# !pip install flash-attn --no-build-isolation




In [2]:

from typing import Dict, Any, List
from datasets import load_dataset, DatasetDict, concatenate_datasets
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

import random
import os

SEED = 42
random.seed(SEED)


In [3]:

# Choose your base model (7B chat model recommended)
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-2-7b-hf")
# You must have access to the model on HF. If needed: `huggingface-cli login`

print("Loading tokenizer:", MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# Truncation/padding so that assistant turn (at the end) is kept
tokenizer.truncation_side = "left"
tokenizer.padding_side    = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading model:", MODEL_NAME)
# torch_dtype="auto" keeps the dtype stored in the checkpoint.
# bf16 and fp16 use the same VRAM; prefer bf16 on GPUs that support it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",     # change to torch.bfloat16 / torch.float16 explicitly if desired
    device_map="auto"       # let HF place layers automatically if multiple GPUs
)
# Disable cache during training
if hasattr(model, "config"):
    model.config.use_cache = False

print("Model & tokenizer ready.")


Loading tokenizer: meta-llama/Llama-2-7b-hf


`torch_dtype` is deprecated! Use `dtype` instead!


Loading model: meta-llama/Llama-2-7b-hf



Model & tokenizer ready.


In [4]:

# Define a chat template that marks assistant turns with {% generation %} blocks.
# SFTTrainer uses these markers to compute loss only on assistant tokens
# when `assistant_only_loss=True` is set in SFTConfig.
tokenizer.chat_template = """{% for message in messages %}
<|{{ message['role'] }}|>
{% if message['role'] == 'assistant' -%}
{% generation %}
{{ message['content'] }}
{% endgeneration %}
{%- else -%}
{{ message['content'] }}
{%- endif %}

{% endfor %}"""

print("Chat template with assistant masking is set.")


Chat template with assistant masking is set.


In [5]:

# Datasets to pull
ARGILLA_BITEXT = [
    ("argilla/customer_assistant", {}),
    ("argilla/synthetic-sft-customer-support-single-turn", {}),
    ("bitext/Bitext-customer-support-llm-chatbot-training-dataset", {}),
]
DOLLY_SLICE = "train[:1000]"  # Adjust as needed

def extract_user_assistant(record: Dict[str, Any]) -> Dict[str, str]:
    """Map heterogeneous keys to a unified (system, user, assistant) triple."""
    keys = set(record.keys())

    user_candidates = [
        "user","question","prompt","instruction","input","request","query",
        "messages_user","customer","customer_message"
    ]
    assistant_candidates = [
        "assistant","answer","response","output","messages_assistant",
        "agent","agent_response"
    ]
    system_candidates = ["system","context","role","scenario"]

    def pick(cands):
        # 1) flat keys
        for c in cands:
            if c in record and isinstance(record[c], str) and len(record[c].strip())>0:
                return record[c]
        # 2) nested messages
        for k in keys:
            v = record[k]
            if isinstance(v, list) and len(v)>0 and isinstance(v[0], dict) and "role" in v[0] and "content" in v[0]:
                u, a, s = None, None, None
                for m in v:
                    role = (m.get("role","") or "").lower()
                    content = m.get("content","")
                    if role == "system" and not s:
                        s = content
                    if role in ("user","customer") and not u:
                        u = content
                    if role in ("assistant","agent") and not a:
                        a = content
                if cands is user_candidates and u:
                    return u
                if cands is assistant_candidates and a:
                    return a
                if cands is system_candidates and s:
                    return s
        return None

    user = pick(user_candidates)
    assistant = pick(assistant_candidates)
    system = pick(system_candidates) or "You are a helpful, step-by-step customer support assistant."

    # fallback: synthesize from any two text fields
    if not user or not assistant:
        textish = [k for k in keys if isinstance(record.get(k), str) and len(record[k].strip())>0]
        if len(textish)>=2 and not user:
            user = record[textish[0]]
        if len(textish)>=2 and not assistant:
            assistant = record[textish[1]]

    if not user or not assistant:
        raise ValueError("Could not map record to user/assistant")

    return {"system": system.strip(), "user": user.strip(), "assistant": assistant.strip()}


In [6]:

def preprocess_to_messages(hf_id: str, hf_config: Dict[str, Any]):
    ds = load_dataset(hf_id, **hf_config)
    split_names = list(ds.keys()) if isinstance(ds, DatasetDict) else ["train"]
    out_splits = []

    for split in split_names:
        d = ds[split] if isinstance(ds, DatasetDict) else ds
        def _map(rec):
            try:
                ex = extract_user_assistant(rec)
                msgs = []
                if ex.get("system"):
                    msgs.append({"role":"system","content":ex["system"]})
                msgs.append({"role":"user","content":ex["user"]})
                msgs.append({"role":"assistant","content":ex["assistant"]})
                return {"messages": msgs}
            except Exception:
                return {"messages": None}

        d2 = d.map(_map, remove_columns=[c for c in d.column_names if c!="messages"])
        d2 = d2.filter(lambda r: r["messages"] is not None)
        out_splits.append(d2)

    merged = concatenate_datasets(out_splits) if len(out_splits)>1 else out_splits[0]
    return merged

def build_argilla_bitext_messages():
    all_ds = []
    for name, cfg in ARGILLA_BITEXT:
        print("Loading:", name, cfg)
        try:
            all_ds.append(preprocess_to_messages(name, cfg))
        except Exception as e:
            print("!! Skipped", name, "due to", e)
    if not all_ds:
        raise RuntimeError("No argilla/bitext datasets loaded - check access/names.")
    merged = concatenate_datasets(all_ds)
    return merged


In [7]:

def build_dolly_messages(slice_spec=DOLLY_SLICE):
    ds = load_dataset("databricks/databricks-dolly-15k", split=slice_spec)
    def _fmt(sample):
        instruction = sample.get("instruction","") or ""
        context = sample.get("context","") or ""
        response = sample.get("response","") or ""
        user = f"{instruction.strip()} {context.strip()}".strip()
        assistant = response.strip()
        msgs = []
        # Dolly usually has no system by default
        if user:
            msgs.append({"role":"user","content":user})
        if assistant:
            msgs.append({"role":"assistant","content":assistant})
        return {"messages": msgs if msgs else None}

    ds = ds.map(_fmt, remove_columns=[c for c in ds.column_names if c!="messages"])
    ds = ds.filter(lambda r: r["messages"] is not None)
    return ds


In [8]:

argilla_bitext = build_argilla_bitext_messages()
dolly_subset   = build_dolly_messages(DOLLY_SLICE)

dataset = concatenate_datasets([argilla_bitext, dolly_subset]).shuffle(seed=SEED)
print("Merged size (raw):", len(dataset))

def has_assistant(ex):
    m = ex.get("messages")
    # Check the type first so a missing/None value cannot break the role lookup
    if not isinstance(m, list):
        return False
    return any(x.get("role") == "assistant" for x in m)

def min_tokens_chat(ex, min_tokens=8):
    s = tokenizer.apply_chat_template(ex["messages"], tokenize=False, add_generation_prompt=False)
    ids = tokenizer(s, add_special_tokens=False).input_ids
    return ids is not None and len(ids) >= min_tokens

dataset = dataset.filter(has_assistant)
dataset = dataset.filter(min_tokens_chat)
print("After basic filtering:", len(dataset))

# Quick peek
for i in range(2):
    s = tokenizer.apply_chat_template(dataset[i]["messages"], tokenize=False, add_generation_prompt=False)
    print("----", s[:400])


Loading: argilla/customer_assistant {}
Loading: argilla/synthetic-sft-customer-support-single-turn {}
Loading: bitext/Bitext-customer-support-llm-chatbot-training-dataset {}
Merged size (raw): 28168
After basic filtering: 28168
---- <|system|>
You are a helpful, step-by-step customer support assistant.
<|user|>
I want help seeing in which cases can I ask for a refund
<|assistant|>
Of course, I'm here to provide you with a detailed understanding of the cases in which you can request a refund. Here are some common scenarios:

1. **Product/Service Defect:** If the product or service you purchased is defective, damaged, or doesn'
---- <|system|>
You are a helpful, step-by-step customer support assistant.
<|user|>
i have a problem setting a delivery address up
<|assistant|>
I apologize for the difficulties you're experiencing while setting up your delivery address. I'm here to assist you in resolving this issue.

Setting up a delivery address is an essential part of the process, and I compl

In [9]:

n = len(dataset)
k = min(100, max(1, n // 50))  # about 2% or cap at 100
split = dataset.train_test_split(test_size=k, seed=SEED, shuffle=True)
train_dataset, eval_dataset = split["train"], split["test"]

print("train:", len(train_dataset), "eval:", len(eval_dataset))


train: 28068 eval: 100
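
As a quick check of the split arithmetic, the sketch below reproduces the eval-size rule from the cell above as a standalone function. The name `eval_size` and the extra sample sizes are illustrative only.

```python
def eval_size(n, cap=100, denominator=50):
    """About 2% of the data (n // 50), at least 1 example, capped at `cap`."""
    return min(cap, max(1, n // denominator))


n = 28168  # dataset size after filtering, from the output above
k = eval_size(n)
print(k, n - k)         # 100 28068
print(eval_size(30))    # 1   (floor of 30/50 is 0, so the minimum of 1 applies)
print(eval_size(2000))  # 40  (2% of 2000, below the cap)
```

With 28,168 examples the 2% rule would give 563, so the cap of 100 determines the eval set size.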


In [10]:

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj","v_proj"],  # add k_proj/o_proj if needed
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
print("LoRA config ready.")


LoRA config ready.
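
For a sense of scale, the following sketch estimates how many trainable parameters this LoRA configuration adds. It assumes the Llama-2-7B architecture (hidden size 4096, 32 decoder layers, square `q_proj` and `v_proj` matrices); the helper `lora_params` is hypothetical, and the figures would differ for other base models.

```python
def lora_params(in_features, out_features, r):
    """LoRA adds A (r x in_features) and B (out_features x r) per adapted layer."""
    return r * (in_features + out_features)


HIDDEN = 4096                    # assumed hidden size of Llama-2-7B
LAYERS = 32                      # assumed number of decoder layers
R, ALPHA = 16, 32                # r and lora_alpha from the config above
TARGETS = ["q_proj", "v_proj"]   # target_modules from the config above

per_module = lora_params(HIDDEN, HIDDEN, R)
total = per_module * len(TARGETS) * LAYERS
print(per_module)  # 131072
print(total)       # 8388608
print(ALPHA / R)   # 2.0 (standard LoRA scaling factor alpha / r)
```

Under these assumptions the adapter has roughly 8.4M trainable parameters, a small fraction of the roughly 7B frozen base weights.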


In [11]:

sft_args = SFTConfig(
    output_dir="./sft_custom_results",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    logging_steps=20,
    save_steps=200,
    save_total_limit=2,
    optim="adamw_torch",
    fp16=True,                 # On multi-GPU auto mapping, bf16=True can be safer
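    packing=True,              # ✅ compatible with template masking
    # NOTE: the {% generation %} mask only applies when assistant_only_loss=True.
    # The printed config below reports assistant_only_loss=False, so add
    # `assistant_only_loss=True` here to train on assistant turns only.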
    max_length=512,
    eval_strategy="steps",
    eval_steps=100,
    load_best_model_at_end=True,
    gradient_checkpointing=True,   # big VRAM saver
)
print(sft_args)


SFTConfig(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
activation_offloading=False,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
assistant_only_loss=False,
auto_find_batch_size=False,
average_tokens_across_devices=True,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
chat_template_path=None,
completion_only_loss=None,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
dataset_kwargs=None,
dataset_num_proc=None,
dataset_text_field=text,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=True,
do_predict=False,
do_train=False,
eos_token
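
Since `packing=True` is central to the throughput of this setup, here is a simplified sketch of what packing does: several short tokenized examples are grouped so that each group fills at most `max_length` tokens. It uses a plain first-fit-decreasing heuristic on sequence lengths. The function `pack_first_fit` and the example lengths are hypothetical, and TRL's own packing strategy may differ in detail.

```python
def pack_first_fit(lengths, max_length):
    """Greedy first-fit-decreasing: put each sequence in the first bin with room."""
    bins = []  # each bin is a list of sequence lengths that share one packed row
    for length in sorted(lengths, reverse=True):
        length = min(length, max_length)  # over-long sequences are cut to max_length
        for b in bins:
            if sum(b) + length <= max_length:
                b.append(length)
                break
        else:
            bins.append([length])  # no existing bin had room, so open a new one
    return bins


lengths = [300, 200, 150, 120, 90, 600]  # hypothetical token counts per example
bins = pack_first_fit(lengths, max_length=512)
print(bins)  # [[512], [300, 200], [150, 120, 90]]
```

Six examples fit into three rows of at most 512 tokens, instead of six rows that would mostly hold padding.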

In [15]:

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    args=sft_args,
    peft_config=peft_config,
    formatting_func=None,
)
print("Trainer ready. Rendering one sample after template:")
print(tokenizer.apply_chat_template(train_dataset[0]["messages"], tokenize=False, add_generation_prompt=False)[:600])


Padding-free training is enabled, but the attention implementation is not set to a supported flash attention variant. Padding-free training flattens batches into a single sequence, and only the following implementations are known to reliably support this: flash_attention_2, flash_attention_3, kernels-community/flash-attn, kernels-community/flash-attn3, kernels-community/vllm-flash-attn3. Using other implementations may lead to unexpected behavior. To ensure compatibility, set `attn_implementation` in the model configuration to one of these supported options or verify that your attention mechanism can handle flattened sequences.
You are using packing, but the attention implementation is not set to a supported flash attention variant. Packing gathers multiple samples into a single sequence, and only the following implementations are known to reliably support this: flash_attention_2, flash_attention_3, kernels-community/flash-attn, kernels-community/flash-attn3, kernels-community/vllm-fla

The model is already on multiple devices. Skipping the move to device specified in `args`.


Trainer ready. Rendering one sample after template:
<|system|>
You are a helpful, step-by-step customer support assistant.
<|user|>
i have to see the status of purchase {{Order Number}} how do i do it
<|assistant|>
Glad you contacted us! I'm clearly cognizant that you would like to check the status of your purchase with order number {{Order Number}}. To do so, you can navigate to the 'Order Details' section on our website. This section will provide you with all the information regarding your purchase, including its current status. If you have any further questions or need additional assistance, please don't hesitate to ask. I'm here to help!




In [16]:

# 🚀 Train
train_result = trainer.train()
print("Training complete.")
print(train_result)


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


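| Step | Training Loss | Validation Loss | Entropy | Num Tokens | Mean Token Accuracy |
|------|---------------|-----------------|---------|------------|---------------------|
| 100 | 1.3954 | 1.356068 | 1.455166 | 798975.0 | 0.661916 |
| 200 | 1.0009 | 1.028595 | 1.016358 | 1598310.0 | 0.729692 |
| 300 | 0.9278 | 0.968086 | 0.963843 | 2397616.0 | 0.740626 |
| 400 | 0.8808 | 0.92954 | 0.924025 | 3196625.0 | 0.746657 |
| 500 | 0.8972 | 0.906796 | 0.909176 | 3996097.0 | 0.751434 |
| 600 | 0.868 | 0.895678 | 0.894779 | 4796144.0 | 0.754299 |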


Training complete.
TrainOutput(global_step=667, training_loss=1.040491234237465, metrics={'train_runtime': 1886.3211, 'train_samples_per_second': 5.65, 'train_steps_per_second': 0.354, 'total_flos': 2.1139400014041907e+17, 'train_loss': 1.040491234237465, 'epoch': 1.0})
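
The numbers in this log can be cross-checked with a little arithmetic. The sketch below assumes a single data-parallel device (consistent with `_n_gpu=1` in the printed config) and that `Num Tokens` is a cumulative count; both are assumptions for illustration.

```python
per_device_batch = 2    # per_device_train_batch_size
grad_accum = 8          # gradient_accumulation_steps
n_devices = 1           # assumption: one data-parallel device
max_length = 512

seqs_per_step = per_device_batch * grad_accum * n_devices
global_steps = 667
total_packed_seqs = global_steps * seqs_per_step

tokens_at_step_100 = 798975                    # from the metrics table
tokens_per_step = tokens_at_step_100 / 100
fill = tokens_per_step / (seqs_per_step * max_length)

print(seqs_per_step, total_packed_seqs)  # 16 10672
print(round(fill, 3))                    # 0.975
```

Each optimizer step covers 16 packed sequences, and on average about 97.5% of every 512-token row holds real tokens, which suggests packing leaves little capacity unused. The total of 10,672 packed sequences is also close to the reported `train_samples_per_second` multiplied by `train_runtime`.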


In [17]:

# Save LoRA adapter
trainer.save_model()  # saves to output_dir

# Quick test generation
from transformers import TextStreamer

messages = [
    {"role":"system","content":"You are a helpful customer support assistant."},
    {"role":"user","content":"How can I track my order?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    streamer=streamer
)



<|assistant|>
I can see that you're eager to keep tabs on the progress of your order. To track your order, please provide me with your order number or the order reference number. With this information, I'll be able to provide you with the latest updates and provide any assistance you may need.

<|assistant|>
If you have any additional questions or concerns, don't hesitate to let me know. I'm here to help you every step of the way.








