In [31]:
import json
import re
import os
import torch
import pandas as pd
from pathlib import Path
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
from typing import Dict, Optional, List, Any, Tuple, Union
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset, DatasetDict, load_dataset, Features, Value, Sequence
from trl import SFTTrainer, SFTConfig

class Config:
    def __init__(self):
        self.project_dir = Path(os.getcwd()).parent
        self.data_dir = self.project_dir / 'Data'

configobj = Config()
print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("torch cuda:", torch.version.cuda)

torch: 2.7.1
cuda available: True
torch cuda: 12.9


In [15]:
@dataclass
class DataConfig:
    # Provide an already-prepared DatasetDict OR instruct how to load/format one.
    hf_path_or_none: Optional[str] = None  # e.g. path from save_to_disk OR HF hub id
    text_field: str = "text"
    keywords_field: str = "keywords"  # list[str]
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = 1000  # to keep eval fast


@dataclass
class PromptConfig:
    # Generic prompt template; keep KEYWORDS: as the "response tag" for the collator
    system_preamble: str = (
        "You are an expert keyword generator. "
        "Extract concise, relevant keywords for the document below."
    )
    response_tag: str = "KEYWORDS:"
    sep: str = "; "  # how labels are joined for the target text


@dataclass
class TrainConfig:
    model_id: str = "Qwen/Qwen2.5-3B-Instruct"
    output_dir: str = "runs/kwgen"
    load_in_4bit: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    target_modules: Optional[List[str]] = None
    max_seq_len: int = 2048
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    num_train_epochs: int = 2
    learning_rate: float = 2e-4
    logging_steps: int = 50
    eval_steps: int = 500
    save_steps: int = 500
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    seed: int = 42
    bf16: bool = False
    fp16: bool = True
    report_to: Optional[str] = None
    gradient_checkpointing: bool = True
    save_total_limit: int = 2
    optim: str = "adamw_torch" 

In [3]:
def normalize_kw_string(s: str) -> List[str]:
    if s is None:
        return []
    raw = []
    for chunk in s.split("\n"):
        raw.extend([p for w in chunk.split(";") for p in w.split(",")])
    norm, seen = [], set()
    for k in raw:
        k2 = k.strip()
        kn = k2.casefold()
        if k2 and kn not in seen:
            seen.add(kn)
            norm.append(k2)
    return norm

def f1_keywords(preds: List[List[str]], refs: List[List[str]]) -> Dict[str, float]:
    tp = fp = fn = 0
    for p, r in zip(preds, refs):
        ps = set(x.casefold() for x in p)
        rs = set(x.casefold() for x in r)
        tp += len(ps & rs); fp += len(ps - rs); fn += len(rs - ps)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

def _clean_single(t: str) -> str:
    if not isinstance(t, str):
        return t
    # Remove ASCII control chars except '\n' (0x0A)
    t = re.sub(r'[\x00-\x09\x0B-\x0C\x0E-\x1F\x7F]', '', t)
    # Remove CR and TAB explicitly; keep '\n'
    t = t.replace('\r', '').replace('\t', ' ')
    # Collapse spaces (but not '\n')
    t = re.sub(r'[ ]{2,}', ' ', t)
    # Trim per line
    t = '\n'.join(line.strip() for line in t.split('\n'))
    return t.strip()

def clean_key_phrase(x: Union[str, List[str]]) -> Union[str, List[str]]:
    """Clean a string or list[str] of keywords; preserves '\n' in strings."""
    if isinstance(x, list):
        return [_clean_single(t) for t in x if isinstance(t, str) and t.strip()]
    return _clean_single(x)

def _dedup_preserve_order(items: List[str]) -> List[str]:
    """Case-insensitive, order-preserving dedup (keeps first-seen casing)."""
    seen = set()
    out = []
    for it in items:
        if not isinstance(it, str):
            continue
        k = it.strip()
        if not k:
            continue
        kn = k.casefold()  # Unicode-safe
        if kn not in seen:
            seen.add(kn)
            out.append(k)
    return out

def build_hf_dataset_from_pandas(df: pd.DataFrame, seed: int = 42) -> DatasetDict:

    df = df.copy()
    # Clean text (optional; comment out next line if you don't want it)
    df["text"] = df["text"].astype(str).map(_clean_single)

    def to_keywords(v) -> List[str]:
        # split to list if string; already list => keep
        kws = v.split("\n") if isinstance(v, str) else (v if isinstance(v, list) else [])
        kws = clean_key_phrase(kws) or []
        kws = [k for k in kws if k]  # drop empties
        kws = _dedup_preserve_order(kws)
        return kws

    df["keywords"] = df["key"].apply(to_keywords)
    df = df[df["keywords"].map(len) > 0].reset_index(drop=True)

    features = Features({
        "text": Value("string"),
        "keywords": Sequence(Value("string")),
    })

    ds_all = Dataset.from_pandas(
        df[["text", "keywords"]], preserve_index=False, features=features
    ).shuffle(seed=seed)

    split = ds_all.train_test_split(test_size=0.2, seed=seed)
    val_test = split["test"].train_test_split(test_size=0.5, seed=seed)

    return DatasetDict({
        "train": split["train"],
        "validation": val_test["train"],
        "test": val_test["test"],
    })


In [4]:
pd_df = pd.read_parquet(configobj.data_dir / "kw_raw.parquet")
kw_ds = build_hf_dataset_from_pandas(pd_df)
print(f"Number of training samples: {len(kw_ds['train'])}")
print(f"Number of validation samples: {len(kw_ds['validation'])}")
print(f"Number of test samples: {len(kw_ds['test'])}")
kw_ds.save_to_disk(configobj.data_dir / 'keyword_dataset')

Number of training samples: 2799
Number of validation samples: 350
Number of test samples: 350


Saving the dataset (1/1 shards): 100%|██████████| 2799/2799 [00:00<00:00, 11156.31 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 350/350 [00:00<00:00, 4525.34 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 350/350 [00:00<00:00, 4655.33 examples/s]


In [5]:
kw_ds

DatasetDict({
    train: Dataset({
        features: ['text', 'keywords'],
        num_rows: 2799
    })
    validation: Dataset({
        features: ['text', 'keywords'],
        num_rows: 350
    })
    test: Dataset({
        features: ['text', 'keywords'],
        num_rows: 350
    })
})

In [6]:
class KeywordDataModule:
    def __init__(self, data_cfg: DataConfig, prompt_cfg: PromptConfig):
        self.cfg = data_cfg
        self.prompt = prompt_cfg

    def _to_keyword_list(self, v: Any) -> List[str]:
        if v is None:
            return []
        if isinstance(v, list):
            items = [str(x).strip() for x in v if isinstance(x, (str, int, float))]
        else:
            s = str(v)
            parts = []
            for chunk in s.split("\n"):
                for p in chunk.split(";"):
                    parts.extend(p.split(","))
            items = [p.strip() for p in parts]
        seen, out = set(), []
        for k in items:
            if not k:
                continue
            kn = k.casefold()
            if kn not in seen:
                seen.add(kn)
                out.append(k)
        return out

    def _format_example(self, ex: Dict[str, Any]) -> Dict[str, str]:
        doc = (ex.get(self.cfg.text_field, "") or "").strip()
        kw_list = self._to_keyword_list(ex.get(self.cfg.keywords_field, []))
        labels = self.prompt.sep.join(kw_list)
        prompt = (
            f"{self.prompt.system_preamble}\n\n"
            f"DOCUMENT:\n{doc}\n\n"
            f"{self.prompt.response_tag}"
        )
        return {"text": prompt, "labels": labels}

    def load(self, dsd: Optional[DatasetDict] = None) -> DatasetDict:
        if dsd is None:
            if self.cfg.hf_path_or_none is None:
                raise ValueError("Provide a DatasetDict or set DataConfig.hf_path_or_none.")
            try:
                dsd = DatasetDict.load_from_disk(self.cfg.hf_path_or_none)
            except Exception:
                dsd = load_dataset(self.cfg.hf_path_or_none)

        mapped = DatasetDict()
        for split in dsd.keys():
            ds = dsd[split]
            if split == "train" and self.cfg.max_train_samples:
                ds = ds.select(range(min(self.cfg.max_train_samples, len(ds))))
            if split in ("validation", "test") and self.cfg.max_eval_samples:
                ds = ds.select(range(min(self.cfg.max_eval_samples, len(ds))))
            ds = ds.map(
                self._format_example,
                remove_columns=[c for c in ds.column_names if c not in ("text", "labels")],
            )
            ds = ds.filter(lambda ex: bool(ex["labels"] and ex["labels"].strip()))
            mapped[split] = ds

        for s in mapped.keys():
            cols = set(mapped[s].column_names)
            if not {"text", "labels"}.issubset(cols):
                raise RuntimeError(f"Split '{s}' must contain 'text' and 'labels', got: {cols}")
        return mapped

In [25]:
@dataclass
class KeywordCompletionCollator:
    tokenizer: PreTrainedTokenizerBase
    response_template: str = "KEYWORDS:"
    max_length: int = 2048
    pad_to_multiple_of: Optional[int] = 8
    add_eos: bool = True

    # ---------- helpers ----------
    def _pad_right(
        self,
        seqs: List[List[int]],
        pad_id: int,
        max_len: int,
        pad_multiple: Optional[int] = None,
    ) -> torch.Tensor:
        # truncate
        seqs = [s[:max_len] for s in seqs]
        tgt_len = max(len(s) for s in seqs) if seqs else 0
        if pad_multiple and tgt_len % pad_multiple:
            tgt_len = ((tgt_len + pad_multiple - 1) // pad_multiple) * pad_multiple
        out = torch.full((len(seqs), tgt_len), pad_id, dtype=torch.long)
        for i, s in enumerate(seqs):
            out[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        return out

    def _get_text_labels(self, ex: Dict[str, Any]) -> Tuple[str, str]:
        if "text" in ex and "labels" in ex:
            return ex["text"], ex["labels"]
        if "prompt" in ex and "response" in ex:
            return ex["prompt"], ex["response"]
        raise KeyError(f"Example must contain ('text','labels') or ('prompt','response'), got: {list(ex.keys())}")

    # ---------- main ----------
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Case A: pre-tokenized items: {'input_ids': ..., 'labels': ...} (attention_mask optional)
        if "input_ids" in batch[0]:
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                # ensure pad token exists
                pad_id = self.tokenizer.eos_token_id or 0

            # Collect fields
            input_ids = [ex["input_ids"] for ex in batch]
            labels_in = [ex.get("labels", None) for ex in batch]
            attn = [ex.get("attention_mask", None) for ex in batch]

            # Truncate/pad
            input_ids_t = self._pad_right(input_ids, pad_id=pad_id, max_len=self.max_length,
                                          pad_multiple=self.pad_to_multiple_of)

            # attention_mask: 1 where not pad
            if any(a is not None for a in attn):
                # If some provided, recompute from padded input_ids to be safe
                attention_mask_t = (input_ids_t != pad_id).long()
            else:
                attention_mask_t = (input_ids_t != pad_id).long()

            # labels: if list[int] provided, pad with -100; else create ignore-only labels
            labels_list: List[List[int]] = []
            for ids, lab in zip(input_ids, labels_in):
                if isinstance(lab, list) and all(isinstance(x, int) for x in lab):
                    labels_list.append(lab)
                else:
                    # default: ignore loss everywhere (trainer may not expect this, but safe)
                    labels_list.append([-100] * len(ids))

            labels_t = self._pad_right(labels_list, pad_id=-100, max_len=self.max_length,
                                       pad_multiple=self.pad_to_multiple_of)

            # If some labels seqs are shorter than inputs, ensure padding positions are -100
            if labels_t.shape[1] < input_ids_t.shape[1]:
                # expand labels to match inputs
                expanded = torch.full_like(input_ids_t, -100)
                expanded[:, :labels_t.shape[1]] = labels_t
                labels_t = expanded
            elif labels_t.shape[1] > input_ids_t.shape[1]:
                labels_t = labels_t[:, :input_ids_t.shape[1]]

            return {
                "input_ids": input_ids_t,
                "attention_mask": attention_mask_t,
                "labels": labels_t,
            }

        # Case B: string format -> build prompt + labels and tokenize here
        prompts, targets = [], []
        for ex in batch:
            p, t = self._get_text_labels(ex)
            prompts.append(p)
            targets.append("" if t is None else t)

        eos = self.tokenizer.eos_token if (self.add_eos and self.tokenizer.eos_token) else ""
        full_texts = [
            p + (" " if (t and not p.endswith(" ")) else "") + t + eos
            for p, t in zip(prompts, targets)
        ]

        enc = self.tokenizer(
            full_texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]

        # Find boundary after "prompt + response_template"
        tagged_prompts = [
            p if p.strip().endswith(self.response_template)
            else (p.rstrip() + " " + self.response_template)
            for p in prompts
        ]
        tag_enc = self.tokenizer(
            tagged_prompts,
            padding=True,
            truncation=True,
            max_length=input_ids.size(1),
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        tag_lens = tag_enc["attention_mask"].sum(dim=1)
        seq_lens = attention_mask.sum(dim=1)

        labels = input_ids.clone()
        labels[:] = -100
        for i in range(input_ids.size(0)):
            start = int(tag_lens[i].item())
            end = int(seq_lens[i].item())
            if start < end:
                labels[i, start:end] = input_ids[i, start:end]

        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [26]:
class QLoRALoader:
    def __init__(self, cfg: TrainConfig):
        self.cfg = cfg
        self.tokenizer = None
        self.model = None

    def _assert_no_bnb(self):
        # If bitsandbytes is importable, PEFT will try to use its path and load Triton/CUDA headers.
        try:
            import bitsandbytes  # noqa: F401
            raise RuntimeError(
                "bitsandbytes is installed but 4-bit is disabled. "
                "Uninstall bitsandbytes (pip uninstall -y bitsandbytes) to avoid cuda.h build."
            )
        except Exception:
            # OK: not installed (ImportError) or we raised above
            pass

    def _guess_target_modules(self, model: torch.nn.Module) -> List[str]:
        candidates = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        present, seen, uniq = [], set(), []
        for n, _ in model.named_modules():
            for c in candidates:
                if n.endswith(c):
                    present.append(c)
        for x in present:
            if x not in seen:
                seen.add(x); uniq.append(x)
        return uniq or ["q_proj", "k_proj", "v_proj", "o_proj"]

    def load(self):
        print(f"Loading tokenizer: {self.cfg.model_id}")
        tok = AutoTokenizer.from_pretrained(self.cfg.model_id, use_fast=True)
        tok.padding_side = "right"
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token

        print(f"Loading model: {self.cfg.model_id}")
        # Force-disable 4-bit completely
        if self.cfg.load_in_4bit:
            print("[INFO] Overriding: load_in_4bit requested but disabled due to CUDA header issues.")
        self.cfg.load_in_4bit = False
        self._assert_no_bnb()  # ensure PEFT won’t import its bnb path

        dtype = torch.bfloat16 if self.cfg.bf16 else torch.float16
        model = AutoModelForCausalLM.from_pretrained(
            self.cfg.model_id,
            dtype=dtype,
            device_map="auto",   # do NOT call .to('cuda') after this
        )

        target_modules = self.cfg.target_modules or self._guess_target_modules(model)
        print("Using LoRA target_modules:", target_modules)
        lora_cfg = LoraConfig(
            r=self.cfg.lora_r,
            lora_alpha=self.cfg.lora_alpha,
            lora_dropout=self.cfg.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()

        self.tokenizer, self.model = tok, model
        return tok, model

In [27]:
class KwGenTrainer:
    def __init__(self, train_cfg: TrainConfig, prompt_cfg: PromptConfig):
        self.cfg = train_cfg
        self.prompt = prompt_cfg
        self.trainer = None

    def build(self, tok, model, dsd_mapped: DatasetDict):
        collator = KeywordCompletionCollator(
            tokenizer=tok,
            response_template=self.prompt.response_tag,
            max_length=self.cfg.max_seq_len,   # <- collator enforces seq length
            pad_to_multiple_of=8,
        )
    
        sft_args = SFTConfig(
            output_dir=self.cfg.output_dir,
            packing=False,
            per_device_train_batch_size=self.cfg.per_device_train_batch_size,
            gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
            num_train_epochs=self.cfg.num_train_epochs,
            learning_rate=self.cfg.learning_rate,
            lr_scheduler_type="cosine",
            logging_steps=self.cfg.logging_steps,
            eval_strategy="steps",
            eval_steps=self.cfg.eval_steps,
            save_steps=self.cfg.save_steps,
            warmup_ratio=self.cfg.warmup_ratio,
            weight_decay=self.cfg.weight_decay,
            bf16=self.cfg.bf16,
            fp16=self.cfg.fp16,
            optim=getattr(self.cfg, "optim", "adamw_torch"),
            report_to=self.cfg.report_to,
            seed=self.cfg.seed,
            gradient_checkpointing=self.cfg.gradient_checkpointing,
            save_total_limit=self.cfg.save_total_limit,
        )
    
        # IMPORTANT: do NOT pass tokenizer=, dataset_text_field=, or max_seq_length=
        self.trainer = SFTTrainer(
            model=model,
            args=sft_args,
            train_dataset=dsd_mapped.get("train"),
            eval_dataset=dsd_mapped.get("validation"),
            data_collator=collator,
        )
        return self.trainer

    @torch.no_grad()
    def evaluate_keywords(self, tok, model, eval_ds: Dataset, sample_size: int = 512, gen_kwargs=None) -> Dict[str, float]:
        if gen_kwargs is None:
            gen_kwargs = dict(max_new_tokens=96, temperature=0.2, do_sample=False, top_p=1.0, repetition_penalty=1.05)

        n = min(sample_size, len(eval_ds))
        prompts = [eval_ds[i]["text"] for i in range(n)]
        refs = [eval_ds[i]["labels"] for i in range(n)]

        preds_kw, refs_kw = [], []
        model.eval()
        device = model.device if hasattr(model, "device") else None

        for p, r in zip(prompts, refs):
            inputs = tok(p, return_tensors="pt")
            if device is not None:
                inputs = {k: v.to(device) for k, v in inputs.items()}
            out = model.generate(**inputs, **gen_kwargs)
            gen_txt = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            preds_kw.append(normalize_kw_string(gen_txt))
            refs_kw.append(normalize_kw_string(r))

        return f1_keywords(preds_kw, refs_kw)

    def train_and_eval(self, tok, model, dsd_mapped: DatasetDict) -> Dict[str, float]:
        self.build(tok, model, dsd_mapped)
        # Enable fast downloads if user has hf-transfer installed
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        self.trainer.train()

        if "validation" in dsd_mapped and len(dsd_mapped["validation"]) > 0:
            metrics = self.evaluate_keywords(tok, self.trainer.model, dsd_mapped["validation"])
        else:
            metrics = {"precision": float("nan"), "recall": float("nan"), "f1": float("nan")}

        adapter_dir = os.path.join(self.cfg.output_dir, "adapter")
        self.trainer.model.save_pretrained(adapter_dir)
        tok.save_pretrained(adapter_dir)
        with open(os.path.join(self.cfg.output_dir, "metrics.json"), "w") as f:
            json.dump(metrics, f, indent=2)
        return metrics

class KeywordGenPipeline:
    def __init__(self, data_cfg: DataConfig, prompt_cfg: PromptConfig, train_cfg: TrainConfig):
        self.data_cfg = data_cfg
        self.prompt_cfg = prompt_cfg
        self.train_cfg = train_cfg
        self.data_module = KeywordDataModule(data_cfg, prompt_cfg)
        self.loader = QLoRALoader(train_cfg)
        self.runner = KwGenTrainer(train_cfg, prompt_cfg)

    def run(self, dsd: Optional[DatasetDict] = None) -> Dict[str, float]:
        dsd_mapped = self.data_module.load(dsd)
        tok, model = self.loader.load()
        metrics = self.runner.train_and_eval(tok, model, dsd_mapped)
        print("Final metrics:", metrics)
        return metrics

In [28]:
data_cfg = DataConfig(
    hf_path_or_none=configobj.data_dir / 'keyword_dataset',
    text_field="text",
    keywords_field="keywords",
    max_train_samples=None,
    max_eval_samples=800,
)
prompt_cfg = PromptConfig(
    system_preamble="You are an expert keyword generator. Extract concise, relevant keywords for the document below.",
    response_tag="KEYWORDS:",
    sep="; ",
)

In [29]:
train_cfg = TrainConfig(
    model_id="Qwen/Qwen2.5-3B-Instruct",   # swap here to try other models (Phi-3.5, Llama-3.2-3B, Gemma-2-2B, etc.)
    output_dir="runs/qwen25_3b_kwgen",
    load_in_4bit=False,
    lora_r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=None,                   # auto-detects q_proj/k_proj/... if None
    max_seq_len=2048,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=2,
    learning_rate=2e-4,
    logging_steps=50,
    eval_steps=500,
    save_steps=500,
    warmup_ratio=0.03,
    weight_decay=0.0,
    seed=42,
    bf16=False, fp16=True,
    report_to=None,
    gradient_checkpointing=True,
    save_total_limit=2,
)

In [30]:
pipe = KeywordGenPipeline(data_cfg, prompt_cfg, train_cfg)
pipe.run()

Loading tokenizer: Qwen/Qwen2.5-3B-Instruct
Loading model: Qwen/Qwen2.5-3B-Instruct


Loading checkpoint shards: 100%|██████████| 2/2 [00:13<00:00,  6.57s/it]


Using LoRA target_modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
trainable params: 29,933,568 || all params: 3,115,872,256 || trainable%: 0.9607


The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
500,0.0,,1.960327,2126990.0,0.0


The following generation flags are not valid and may be ignored: ['temperature', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


NameError: name 'json' is not defined

In [34]:
data_module = KeywordDataModule(data_cfg, prompt_cfg)
dsd_mapped = data_module.load()
train_ds = dsd_mapped["train"]

In [35]:
# 1) Tokenizer used by the collator
tok = AutoTokenizer.from_pretrained(train_cfg.model_id, use_fast=True)
tok.padding_side = "right"
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

In [36]:
# 2) Collator (your class)
collator = KeywordCompletionCollator(
    tokenizer=tok,
    response_template=prompt_cfg.response_tag,  # "KEYWORDS:"
    max_length=train_cfg.max_seq_len,
    pad_to_multiple_of=8,
    add_eos=True,
)

In [37]:
# 3) Grab one batch
loader = DataLoader(train_ds, batch_size=2, shuffle=False, collate_fn=collator)
batch = next(iter(loader))

In [38]:
# 4) Inspect decoded inputs and supervised tokens
def show_example(i: int):
    inp_ids = batch["input_ids"][i]
    lbl_ids = batch["labels"][i]
    attn    = batch["attention_mask"][i]

    full_text = tok.decode(inp_ids, skip_special_tokens=False)
    target_text = tok.decode(lbl_ids[lbl_ids != -100], skip_special_tokens=False)

    print(f"\n===== Example {i} =====")
    print("Ends with response tag? ->", full_text.rstrip().endswith(prompt_cfg.response_tag))
    print("\n[Prompt+Target decoded]\n", full_text)
    print("\n[Supervised target (labels!=-100) decoded]\n", target_text)
    print("\n[counts]",
          " total_tokens=", int(attn.sum().item()),
          " supervised_tokens=", int((lbl_ids != -100).sum().item()))

In [39]:
show_example(0)
if batch["input_ids"].size(0) > 1:
    show_example(1)


===== Example 0 =====
Ends with response tag? -> False

[Prompt+Target decoded]
 You are an expert keyword generator. Extract concise, relevant keywords for the document below.

DOCUMENT:
Plasticity of the Human Auditory Cortex Induced by Discrimination Learning of Non-Native, Mora-Timed Contrasts of the Japanese Language
In this magnetoencephalographic (MEG) study, we examined with high temporal resolution the traces of learning in the speech-dominant left-hemispheric auditory cortex as a function of newly trained mora-timing. In Japanese, the "mora" is a temporal unit that divides words into almost isochronous segments (e.g., na-ka-mu-ra and to-o-kyo-o each comprises four mora). Changes in the brain responses of a group of German and Japanese subjects to differences in the mora structure of Japanese words were compared. German subjects performed a discrimination training in 10 sessions of 1.5 h each day. They learned to discriminate Japanese pairs of words (in a consonant, anni --an