In [None]:
# Annotation guideline text for the "formatting" issue category.
# Plain triple-quoted string (no f-prefix): the text has no placeholders, and a
# plain literal stays safe if braces are ever added to the examples.
# NOTE(review): the author name is spelled "Vatswani" in examples i/ii and
# "Vastwani" in example iii — confirm which spelling is intended.
format_def = """Issues in citation formatting such as a missing bracket and using the wrong style of citing.
    a) Due to preprocessing errors of the source dataset, some words contain hyphens that do not require it, and some are missing hyphens where it is required. Please ignore these types of formatting issues.
    b) Highlight the word/citation in which the formatting issue occurs in and not only the issue within the word/citation.
    c) Formatting issues appear as either citations or parts of a citation.
    Examples of formatting issues include:
        i) Narrative citation missing year: “Vatswani et al.” -> should be “Vatswani et al. (2020)”
        ii) Wrong citation style: “In (Vatswani et al., 2019)” -> should be “in Vatswani et al. (2019)”
        iii) Wrong use of footnotes: "Vastwani et al. 1" -> should include the year or be reformatted as a proper footnote."""

# Annotation guideline text for the "unsupported claim" issue category.
# Plain triple-quoted string (no f-prefix): the text has no placeholders, and a
# plain literal stays safe if braces are ever added to the examples.
unsupp_def = """claim about prior work or statistics w/o citation or evidence. 
    a) The author should cite at every first mention of a study, paper, shared task, competition or dataset.
    b) Specific information to a niche topic, despite sounding like it should be known in that topic of study, should be cited.
    c) If a claim is made and is obvious to be a natural deduction from previous statements through common sense (i.e not requiring expert knowledge), then this claim does not fall under ‘Unsupported claim’. For example:
        i) “However, creating a large and suitable set of questions for supporting narrative comprehension is both time-consuming and cognitively demanding.” -> it is obvious that creating a dataset is time consuming and mentally demanding.
    d) Any mention of “recent works” should be backed up with citations to the works.
    e) Unsupported claim issues appear as segments, phrases, sub-sentences or full sentences.
    Examples of unsupported claims include:
        i) Missing citations for mentions of 'recent works': “and there are many recent works that explore this topic”,
        ii) Mention of a previous work and claim without citation: “..., while in a previous study, the authors claim …”,
        iii) Mentioning of a specific setup of a task without citation to the work: ".. BERT was used in an AES task trained on essays .." """

# Annotation guideline text for the "lack of synthesis" issue category.
# Plain triple-quoted string (no f-prefix): the text has no placeholders, and a
# plain literal stays safe if braces are ever added to the examples.
lacksynth_def = """occurs when either:
    a) The author describes or cites papers without connecting them to their own work/argument 
    b) Or only follows up the summary of previous works with their own contribution without explicitly highlighting the gap their work intends to research.
    c) It does not articulate the author's perspective or motivation.
    d) A lack of argument/opinion in the first paragraph is permissible as it serves to be the foundation of the author's argument 
    e) Lacks synthesis issues appear either as single sentences or multiple sentences.
    Examples of lack of synthesis include:
        i) No elaboration of own contribution/argument:"Following early neural approaches to question answering, many subsequent studies adopt a pipeline architecture consisting of retrieval and comprehension components. The retrieval component focuses on identifying relevant documents or passages from a large corpus, while the comprehension component extracts an answer span from the retrieved text. Initial models relied on recurrent neural networks with attention mechanisms to encode questions and contexts (Seo et al., 2017; Wang et al., 2017)."
        ii)  No explanation of the cited works and relation to their own work: “Recently, several studies have explored the use of prompting techniques with pre-trained language models to influence model outputs or access latent knowledge (Brown et al., 2020; Gao et al., 2021; Liu et al., 2021; Wei et al., 2022).” """

# Annotation guideline text for the "coherence" issue category.
# Plain triple-quoted string (no f-prefix): the text has no placeholders, and a
# plain literal stays safe if braces are ever added to the examples.
# NOTE(review): example iii is a single sentence labelled "No explanation of
# the cited works and relation to their own work", while rule c) says coherence
# issues appear only as multiple sentences — confirm it belongs in this category.
coherence_def = """connection between cited works is abrupt, lacking relation to each other. It is unclear how one mentioned work is relevant to a prior mentioned work. 
    a) Sentences are not transitioned from one to another.
    b) The relationship between sentences describing papers is implied but not explicitly stated.
    c) Coherence issues appear only as multiple sentences.
    Examples of coherence issues include:
        i) Relation between mentioned works is not explicit: “Smith (2020) identified a relationship between personal belief systems and ethical decision-making frameworks. Moral foundation theory proposes several core dimensions of moral reasoning, including harm, fairness, and authority (Jones, 2015). Audience adaptation has been explored in computational argumentation. Lee et al. (2019) applied moral categories to argument generation tasks. Human annotators often disagree when labeling moral dimensions in text (Nguyen et al., 2018).”
        ii) Lack of transitions between sentences: “Recent studies have explored various techniques for enhancing model performance. Smith et al. (2020) introduced a novel architecture that significantly improves accuracy on benchmark datasets. Additionally, Johnson and Lee (2019) proposed a data augmentation method that increases training data diversity.” 
        iii) No explanation of the cited works and relation to their own work: “Recently, several studies have explored the use of prompting techniques with pre-trained language models to influence model outputs or access latent knowledge (Brown et al., 2020; Gao et al., 2021; Liu et al., 2021; Wei et al., 2022).” """

Training loop

In [None]:
# Command-line form of this notebook (the next cell parses --config and --category with argparse):
#   python sft_train_decoder_prompt_completion.py --config config.json --category "$CAT"

In [None]:
import argparse
import os 
import json 
from pathlib import Path

def load_config(config_path: str) -> dict:
    """Parse the UTF-8 JSON file at ``config_path`` and return its content.

    Raises whatever ``open`` raises for an unreadable path and
    ``json.JSONDecodeError`` for malformed JSON.
    """
    with open(config_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def merge_dicts(base: dict, override: dict) -> dict:
    """Return a shallow copy of ``base`` with the entries of ``override`` applied on top.

    ``override`` may be ``None`` or empty, in which case the copy of ``base``
    is returned unchanged. Neither argument is mutated.
    """
    merged = dict(base)
    # `or {}` lets callers pass None for "no overrides".
    merged.update((override or {}).items())
    return merged

# ---- Command-line arguments ----
# parse_args() reads sys.argv, so both flags must be supplied on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="Path to config.json")
parser.add_argument("--category", required=True, help="Category name from config.json")
args_cli = parser.parse_args()

cfg = load_config(args_cli.config)

# Global settings shared by every category; `or {}` tolerates an explicit null.
defaults = cfg.get("defaults") or {}
base_model = cfg.get("base_model", "meta-llama/Llama-3.1-8B-Instruct")

# ---- Find the category block ----
# The first entry of cfg["categories"] whose "name" equals --category wins.
cat_cfg = next(
    (entry for entry in (cfg.get("categories") or []) if entry.get("name") == args_cli.category),
    None,
)
if cat_cfg is None:
    raise ValueError(f"Category '{args_cli.category}' not found in config.json")

# Per-category values override the defaults. This must happen only after the
# category lookup above, because every setting below is read from run_cfg.
run_cfg = merge_dicts(defaults, cat_cfg)

# ---- Dataset paths ----
train_path = run_cfg["train_path"]  # required; KeyError if absent
dev_path = run_cfg.get("dev_path")
eval_path = run_cfg.get("eval_path")

# "dev" selects dev_path for in-training evaluation; any other value selects eval_path.
eval_split = run_cfg.get("eval_split", "dev")  # dev or eval
eval_path_for_training = dev_path if eval_split == "dev" else eval_path

# ---- Output directory ----
# Save into: <folder containing train_path>/<output_subdir>/<run_name>
# (defaults: output_subdir="models", run_name=<category>).
data_dir = Path(train_path).resolve().parent

output_subdir = run_cfg.get("output_subdir", "models")  # configurable if you want
run_name = run_cfg.get("run_name", args_cli.category)  # configurable run folder name

out_dir = str(data_dir / output_subdir / run_name)


In [13]:
import json
from typing import Dict, List, Any
from torch.utils.data import Dataset

class PromptCompletionDataset(Dataset):
    """Map-style dataset of ``{"prompt": ..., "completion": ...}`` examples read from disk.

    Accepted file layouts (decoded as UTF-8, a leading BOM is tolerated):
      * a JSON array of example dicts,
      * a JSON object wrapping such an array under "data", "rows", "examples" or "items",
      * a single JSON object that is itself one example,
      * JSONL: one JSON object per line (blank lines are skipped).

    Args:
        path: File to load.
        prompt_key: Key holding the prompt text in each example.
        completion_key: Key holding the completion text in each example.

    Raises:
        ValueError: If the file is empty, cannot be parsed, has an unsupported
            top-level shape, or contains an example that is not a dict with
            both keys.
    """

    def __init__(self, path: str, prompt_key: str = "prompt", completion_key: str = "completion"):
        self.rows: List[Dict[str, Any]] = []
        self.prompt_key = prompt_key
        self.completion_key = completion_key

        # Read the file once; the same text feeds both the JSON and JSONL parsers.
        with open(path, "r", encoding="utf-8-sig") as f:
            raw = f.read()
        text = raw.strip()

        if not text:
            raise ValueError(f"{path} is empty")

        # Heuristic: if it starts with '[' or '{', try whole-file JSON first
        # (covers .json and pretty-printed arrays).
        if text[0] in "[{":
            try:
                obj = json.loads(text)
            except json.JSONDecodeError:
                # Multi-line JSONL also starts with '{' but is not one JSON
                # document; fall through to the line-by-line parser.
                pass
            else:
                self.rows = self._normalize_json(obj, path)
                self._validate_rows(path)
                return

        self.rows = self._parse_jsonl(raw, path)
        self._validate_rows(path)

    def _parse_jsonl(self, raw: str, path: str) -> List[Dict[str, Any]]:
        """Parse ``raw`` as JSONL, skipping blank lines and reporting the first bad line."""
        rows: List[Dict[str, Any]] = []
        # Text-mode reading already normalised line endings to "\n", so splitting
        # on "\n" keeps line numbers aligned with the file on disk.
        for lineno, line in enumerate(raw.split("\n"), start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as e:
                preview = line[:200].replace("\n", "\\n")
                raise ValueError(
                    f"Failed to parse JSON on line {lineno} in {path}: {e}\n"
                    f"Line preview: {preview}"
                ) from e
        return rows

    def _normalize_json(self, obj: Any, path: str) -> List[Dict[str, Any]]:
        """Turn a parsed JSON document into a list of examples, or raise ValueError."""
        # Accept: list of dicts
        if isinstance(obj, list):
            return obj
        if isinstance(obj, dict):
            # Accept: {"data": [...]} or {"rows": [...]} (common variants)
            for k in ("data", "rows", "examples", "items"):
                if k in obj and isinstance(obj[k], list):
                    return obj[k]
            # Accept: a lone example object. A one-line JSONL file parses as a
            # single JSON document, so without this case it would be rejected.
            if self.prompt_key in obj and self.completion_key in obj:
                return [obj]
        raise ValueError(
            f"{path} parsed as JSON but is not a list of examples or a dict containing a list "
            f"(expected e.g. [{{...}}, ...] or {{'data': [...]}}). Got: {type(obj)}"
        )

    def _validate_rows(self, path: str) -> None:
        """Check that ``self.rows`` is a list of dicts that all carry both configured keys."""
        if not isinstance(self.rows, list):
            raise ValueError(f"{path}: expected a list of examples, got {type(self.rows)}")

        for i, r in enumerate(self.rows):
            if not isinstance(r, dict):
                raise ValueError(f"{path}: example {i} is not a dict (got {type(r)})")
            if self.prompt_key not in r or self.completion_key not in r:
                raise ValueError(
                    f"{path}: example {i} missing keys "
                    f"'{self.prompt_key}' and/or '{self.completion_key}'. Keys: {list(r.keys())}"
                )

    def __len__(self) -> int:
        """Number of loaded examples."""
        return len(self.rows)

    def __getitem__(self, idx: int) -> Dict[str, str]:
        """Return example ``idx`` re-keyed to the fixed names "prompt" and "completion"."""
        r = self.rows[idx]
        return {"prompt": r[self.prompt_key], "completion": r[self.completion_key]}


In [32]:
from dataclasses import dataclass
from typing import List, Dict, Any
import torch

# Label value for positions excluded from the loss
# (matches the default ignore_index of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100


@dataclass
class PromptCompletionCollator:
    """Collate ``{"prompt", "completion"}`` examples into causal-LM training tensors.

    Attributes are documented inline below. The instance is callable and is
    meant to be passed as a ``data_collator``.
    """

    # Callable tokenizer; invoked with (texts, truncation, max_length, padding,
    # return_tensors) and expected to return "input_ids" and "attention_mask" tensors.
    tokenizer: Any
    # Truncation length (in tokens) applied to prompt + completion.
    max_length: int = 2048

    def __call__(self, batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
        """
        Batch is a list of:
          {"prompt": str, "completion": str}

        We build:
          input_ids      = tokenizer(prompt + completion)
          attention_mask = usual
          labels         = input_ids, with prompt tokens AND padding positions
                           set to IGNORE_INDEX so only completion tokens are trained on

        Note: the prompt length is measured by tokenizing the prompt on its
        own, so it is approximate if the tokenizer merges characters across
        the prompt/completion boundary. If truncation removes the whole
        completion, every label of that example ends up as IGNORE_INDEX.
        """
        prompts = [ex["prompt"] for ex in batch]
        completions = [ex["completion"] for ex in batch]

        # Important: ensure completion starts immediately after prompt
        # (your data should already include leading space/newline if needed)
        full_texts = [p + c for p, c in zip(prompts, completions)]

        enc_full = self.tokenizer(
            full_texts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        )

        # Tokenize prompts alone to know prompt token lengths
        enc_prompt = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            return_tensors="pt",
        )

        input_ids = enc_full["input_ids"]
        attention_mask = enc_full["attention_mask"]

        labels = input_ids.clone()

        # Padding must never contribute to the loss. Without this, shorter
        # examples in a batch would be trained to emit the pad token.
        labels[attention_mask == 0] = IGNORE_INDEX

        # Non-padding token counts per example.
        prompt_lens = enc_prompt["attention_mask"].sum(dim=1).tolist()
        full_lens = attention_mask.sum(dim=1).tolist()

        seq_len = input_ids.shape[1]
        left_padded = getattr(self.tokenizer, "padding_side", "right") == "left"

        # Mask out prompt tokens from loss
        for i, prompt_len in enumerate(prompt_lens):
            # With left padding the real tokens start after the pad run.
            start = seq_len - full_lens[i] if left_padded else 0
            labels[i, start:start + prompt_len] = IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


In [None]:
import os, json, socket, platform
from datetime import datetime
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# ==================================
# A100-oriented precision defaults
# ==================================
# Only fills in BF16/FP16 when the environment does not already define them.
os.environ.setdefault("BF16", "1")
os.environ.setdefault("FP16", "0")


def _setting(key, default, cast):
    """Return ``cast(value)`` where value is ``run_cfg[key]`` or ``default`` if the key is absent."""
    return cast(run_cfg.get(key, default))


# ---- Settings from config ----
model_name = run_cfg.get("model_name", base_model)
eval_path = eval_path_for_training  # the split chosen for in-training evaluation

print("Settings (from config):")
print(f"Category: {args_cli.category}")
print(f"Model: {model_name}")
print(f"Train data: {train_path}")
print(f"Eval data: {eval_path}")
print(f"Output dir: {out_dir}")

# Sequence length and batching
max_length = _setting("max_length", 2048, int)
per_device_train_batch_size = _setting("per_device_train_batch_size", 1, int)
# Eval batch size falls back to the train batch size when not configured.
per_device_eval_batch_size = _setting("per_device_eval_batch_size", per_device_train_batch_size, int)
gradient_accumulation_steps = _setting("gradient_accumulation_steps", 16, int)

# Optimisation schedule
learning_rate = _setting("learning_rate", 2e-4, float)
num_train_epochs = _setting("num_train_epochs", 1, float)
warmup_ratio = _setting("warmup_ratio", 0.03, float)

# Logging / checkpointing
logging_steps = _setting("logging_steps", 25, int)
save_steps = _setting("save_steps", 500, int)
save_reason = _setting("save_reason", "config-driven run", str)

# Mixed precision
bf16 = _setting("bf16", True, bool)
fp16 = _setting("fp16", False, bool)

# LoRA hyper-parameters
lora_r = _setting("lora_r", 16, int)
lora_alpha = _setting("lora_alpha", 32, int)
lora_dropout = _setting("lora_dropout", 0.05, float)

# Optional accelerations
torch_compile = _setting("torch_compile", False, bool)
use_flash_attn = _setting("flash_attn", False, bool)

# Data loading and gradient clipping
num_workers = _setting("dataloader_num_workers", 4, int)
max_grad_norm = _setting("max_grad_norm", 1.0, float)


In [None]:
# ---- Load tokenizer ----
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Batched padding needs a pad token; reuse EOS when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print("Tokenizer loaded.")

# ---- Load model with 4-bit quantization (QLoRA) ----
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16" if bf16 else "float16",
)

model_load_kwargs = dict(
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype="bfloat16" if bf16 else ("float16" if fp16 else None),
)

# The attention backend is chosen at load time, so FlashAttention-2 has to be
# requested through from_pretrained().
if use_flash_attn:
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            attn_implementation="flash_attention_2",
            **model_load_kwargs,
        )
        print("Enabled flash_attention_2")
    except (ImportError, ValueError) as e:
        # NOTE(review): these are the error types expected when flash-attn is
        # missing or unsupported for this model/dtype — confirm against the
        # installed transformers version. Falls back to the default attention.
        print("Could not enable flash_attention_2:", e)
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)
print("Model loaded.")

In [None]:
# Optional: flash-attn v2 (if installed and supported)
# NOTE(review): this only assigns an attribute on the config of a model that is
# already loaded. transformers documents `attn_implementation` as a
# from_pretrained() argument, so this assignment likely does not switch the
# attention backend, and the "Enabled" message may be misleading — confirm.
if use_flash_attn:
    try:
        model.config.attn_implementation = "flash_attention_2"
        print("Enabled flash_attention_2")
    except Exception as e:
        print("Could not enable flash_attention_2:", e)

# Sum of numel() over all parameters of the loaded model.
# NOTE(review): for a 4-bit quantized model this probably counts packed storage
# elements rather than logical weights — confirm before quoting the number.
print(f"Model loaded with {sum(p.numel() for p in model.parameters())} parameters.")

print("Prepping for kbit training and adding LoRA adapters...")
# Prepare for k-bit training + add LoRA adapters
# Order matters: the quantized base model is prepared first, then wrapped with adapters.
model = prepare_model_for_kbit_training(model)

# Adapter hyper-parameters come from the run config (lora_r / lora_alpha /
# lora_dropout); adapters are attached to the listed projection modules.
lora_cfg = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
)
model = get_peft_model(model, lora_cfg)
# Prints how many parameters are trainable after wrapping.
model.print_trainable_parameters()

# Optional: torch.compile (PyTorch 2.x). Can help throughput; sometimes finicky.
# Best-effort: any failure is printed and the uncompiled model is kept.
if torch_compile:
    try:
        import torch
        model = torch.compile(model)
        print("Enabled torch.compile")
    except Exception as e:
        print("Could not enable torch.compile:", e)

In [33]:
print("Loading datasets...")

# Every row must already contain 'prompt' and 'completion' keys. Raw annotation
# rows (e.g. 'span' / 'document' / 'reason' / 'start' / 'end' / 'label') are
# rejected with a ValueError and have to be converted first.
train_ds = PromptCompletionDataset(train_path)

# Evaluation is optional, but a configured-yet-missing file is reported instead
# of being skipped silently.
if eval_path and os.path.exists(eval_path):
    eval_ds = PromptCompletionDataset(eval_path)
else:
    if eval_path:
        print(f"WARNING: eval file not found, evaluation disabled: {eval_path}")
    eval_ds = None

collator = PromptCompletionCollator(tokenizer=tokenizer, max_length=max_length)


Loading datasets...


ValueError: train.json: example 0 missing keys 'prompt' and/or 'completion'. Keys: ['span', 'document', 'reason', 'start', 'end', 'label']

In [None]:
# ---- Training args (A100-tuned) ----
import dataclasses

# Newer transformers releases name the evaluation schedule `eval_strategy`;
# older ones only know `evaluation_strategy`. Use whichever field exists.
_training_arg_names = {field.name for field in dataclasses.fields(TrainingArguments)}
_eval_strategy_key = "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"

args = TrainingArguments(
    output_dir=out_dir,
    overwrite_output_dir=False,  # keep checkpoints for resume

    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,

    # A100: bf16 is preferred
    bf16=bf16,
    fp16=fp16,

    # Optim + scheduler
    learning_rate=learning_rate,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type="cosine",
    optim="adamw_torch",

    # Throughput / stability: config values are the defaults, and the
    # NUM_WORKERS / MAX_GRAD_NORM environment variables still override them.
    dataloader_num_workers=int(os.environ.get("NUM_WORKERS", num_workers)),
    max_grad_norm=float(os.environ.get("MAX_GRAD_NORM", max_grad_norm)),

    gradient_checkpointing=True,  # reduces activation memory; helps longer context
    tf32=True,                    # A100 supports TF32; can improve matmul speed

    # The dataset yields {"prompt", "completion"} dicts that only the collator
    # understands. With the default (True) the Trainer drops keys that are not
    # arguments of the model's forward() before collation, which would remove both.
    remove_unused_columns=False,

    # Logging/saving
    logging_steps=logging_steps,
    save_steps=save_steps,
    save_total_limit=2,
    save_safetensors=True,
    logging_dir=os.path.join(out_dir, "logs"),

    # Eval: run at every checkpoint interval when an eval set is available
    eval_steps=save_steps if eval_ds is not None else None,
    **{_eval_strategy_key: "steps" if eval_ds is not None else "no"},

    # Optional: save a bit of overhead with no external reporters
    report_to="none",

    # Optional: speed by grouping similar lengths (needs dataset to return lengths or use HF datasets)
    group_by_length=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
)

def _latest_checkpoint(output_dir: str):
    if not os.path.isdir(output_dir):
        return None
    ckpts = []
    for name in os.listdir(output_dir):
        if name.startswith("checkpoint-"):
            p = os.path.join(output_dir, name)
            if os.path.isdir(p):
                try:
                    step = int(name.split("-")[-1])
                except Exception:
                    step = -1
                ckpts.append((step, p))
    if not ckpts:
        return None
    ckpts.sort(key=lambda x: x[0])
    return ckpts[-1][1]

def write_save_reason_json(output_dir: str, reason: str, extra: dict = None):
    """Append one provenance record to ``<output_dir>/save_reasons.json``.

    The file holds a JSON list. Each call adds a dict with a UTC timestamp,
    ``reason``, the module-level ``model_name`` / ``train_path`` / ``eval_path``
    (eval path only when ``eval_ds`` is set), host and platform information,
    plus the entries of ``extra`` (which override the built-in keys).

    Args:
        output_dir: Directory that receives the file; created if missing.
        reason: Free-text explanation stored under "reason".
        extra: Optional additional key/value pairs for the record.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "save_reasons.json")

    record = dict(
        timestamp_utc=datetime.utcnow().isoformat() + "Z",
        reason=reason,
        model_name=model_name,
        train_path=train_path,
        eval_path=eval_path if eval_ds is not None else None,
        out_dir=output_dir,
        host=socket.gethostname(),
        platform=platform.platform(),
        a100_tuned=True,
    )
    if extra:
        record.update(extra)

    # Start from the existing history when there is one. An unreadable
    # (malformed JSON) file is treated as empty and gets overwritten below.
    history = []
    if os.path.exists(log_path):
        with open(log_path, "r", encoding="utf-8") as fh:
            try:
                history = json.load(fh)
            except json.JSONDecodeError:
                history = []

    # A file holding a single non-list value is wrapped so it can be appended to.
    if not isinstance(history, list):
        history = [history]

    history.append(record)

    with open(log_path, "w", encoding="utf-8") as fh:
        json.dump(history, fh, indent=2, ensure_ascii=False)

# ---- Train, resuming from the newest checkpoint in out_dir when one exists ----
resume_ckpt = _latest_checkpoint(out_dir)
if not resume_ckpt:
    print("No checkpoint found; starting fresh.")
    train_output = trainer.train()
else:
    print(f"Resuming from latest checkpoint: {resume_ckpt}")
    train_output = trainer.train(resume_from_checkpoint=resume_ckpt)

# ---- Final save (LoRA adapters + tokenizer) ----
trainer.save_model(out_dir)
tokenizer.save_pretrained(out_dir)

# Snapshot of the run's settings, stored alongside the saved model.
# Missing trainer state / metrics are recorded as -1.
final_save_info = {
    "final_save": True,
    "global_step": int(getattr(trainer.state, "global_step", -1)),
    "train_runtime_sec": float(getattr(train_output, "metrics", {}).get("train_runtime", -1)),
    "lora": {"r": lora_r, "alpha": lora_alpha, "dropout": lora_dropout},
    "quantization": "4bit_nf4",
    "max_length": max_length,
    "per_device_train_batch_size": per_device_train_batch_size,
    "grad_accum": gradient_accumulation_steps,
    "lr": learning_rate,
    "bf16": bf16,
    "fp16": fp16,
    "gradient_checkpointing": True,
    "tf32": True,
    "flash_attn": use_flash_attn,
    "torch_compile": torch_compile,
}
write_save_reason_json(out_dir, reason=save_reason, extra=final_save_info)

save_reasons_path = os.path.join(out_dir, 'save_reasons.json')
print(f"Saved (LoRA adapters + tokenizer) to: {out_dir}")
print(f"Wrote save reasons to: {save_reasons_path}")