# DoAug 파라프레이즈 증강 파이프라인
본 노트북은 **DoAug** 논문의 핵심 단계인 SFT → DPO → Selective Coreset Augmentation 을 다시 구현합니다.
다운스트림 태스크 성능은 다루지 않고, _다양하고 의미를 유지하는 파라프레이즈 텍스트_ 생성에 집중합니다.

In [None]:
# (Install the required packages as appropriate for your runtime environment.)
%pip install -q torch==2.5.1 transformers==4.45.2 datasets sentence-transformers peft accelerate trl==0.11.3 scikit-learn tensorboard

In [None]:
from pathlib import Path
import torch, os, json, random, ast, tensorboard
from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import LoraConfig, get_peft_model
from trl import DPOTrainer, DPOConfig
from sentence_transformers import SentenceTransformer, util
from tqdm.auto import tqdm
import torch.nn.functional as F


# Pin the notebook to the first GPU. torch initialises CUDA lazily, so this still
# takes effect as long as nothing has touched CUDA before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# All datasets, checkpoints and merged models are written under this directory.
WORKDIR = Path("./doaug_artifacts")
WORKDIR.mkdir(exist_ok=True)
# Optional Hugging Face token for the gated Llama weights. When HF_TOKEN is unset
# this is None, and `token=None` makes huggingface_hub fall back to credentials
# cached by `huggingface-cli login` instead of failing with a KeyError here.
HF_TOKEN = os.environ.get("HF_TOKEN")
# System prompt shared by SFT, DPO prompt construction and inference.
SYSTEM_MESSAGE = "You are a helpful assistant that only paraphrases."

print(f"Using {DEVICE}")

1️⃣ Supervised Fine‑Tuning (SFT)
---
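- Build 100 k (sentence, paraphrase) pairs from 20 k source sentences and all of their paraphrases.
- Train a LoRA adapter on the assistant reply only (prompt tokens are masked out of the loss).
- Merge the adapter into the base model.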


#### (i) Create DSFT_100k without overlap with later DDPO set


In [None]:
# Full train split of the paraphrase corpus; both DSFT and DDPO are sampled from it.
raw_all = load_dataset(path="humarin/chatgpt-paraphrases", split="train")

In [None]:
import ast

# Deterministically sample 20 k source sentences; each contributes every paraphrase it has.
raw_dsft = raw_all.shuffle(seed=123).select(range(20_000))

pairs = []
dsft_sources = set()  # remembered so the DDPO set can exclude these sources
for record in raw_dsft:
    source_text = record["text"]
    dsft_sources.add(source_text)
    candidates = record["paraphrases"]
    # The column may hold a stringified Python list instead of a real list.
    candidates = ast.literal_eval(candidates) if isinstance(candidates, str) else candidates
    pairs.extend({"sentence": source_text, "paraphrase": cand} for cand in candidates)

assert len(pairs) == 100_000, "DSFT size should be exactly 100 k"

# One JSON object per line (JSONL), UTF-8, non-ASCII kept as-is.
dsft_path = WORKDIR / "DSFT_100k.jsonl"
with dsft_path.open("w", encoding="utf-8") as sink:
    sink.writelines(json.dumps(pair, ensure_ascii=False) + "\n" for pair in pairs)

print("DSFT saved", dsft_path)

In [None]:
# NOTE(review): this cell repeats the previous cell; re-running it rebuilds the
# same pairs from `raw_dsft` and overwrites the same JSONL file.
dsft_sources = set()
pairs = []
for row in raw_dsft:
    text = row["text"]
    dsft_sources.add(text)
    variants = row["paraphrases"]
    if isinstance(variants, str):  # stringified list -> real list
        variants = ast.literal_eval(variants)
    pairs += [{"sentence": text, "paraphrase": v} for v in variants]

# Sanity check on the expanded pair count.
assert len(pairs) == 100_000, "DSFT size should be exactly 100 k"

dsft_path = WORKDIR / "DSFT_100k.jsonl"
serialized = "".join(json.dumps(pair, ensure_ascii=False) + "\n" for pair in pairs)
with open(dsft_path, "w", encoding="utf-8") as out_file:
    out_file.write(serialized)

print("DSFT saved", dsft_path)

#### (ii) Tokenization

In [None]:
dsft_path = WORKDIR / "DSFT_100k.jsonl"

# Fast tokenizer for the base model; right padding to match the label padding
# done by the collator. If the tokenizer defines no pad token, reuse EOS.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN, use_fast=True)
missing_pad = tokenizer.pad_token is None
if missing_pad:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def format_and_mask_chat(example):
    """Tokenise one (sentence, paraphrase) pair as a chat and mask the prompt.

    The system/user/assistant turns are rendered with the tokenizer's chat
    template. ``labels`` copies ``input_ids`` but replaces the first N positions
    with -100, where N is the length of the prompt rendered with
    ``add_generation_prompt=True``. This assumes that rendering is a token prefix
    of the full conversation, so only the assistant reply contributes to the loss.
    All three sequences are cut to ``tokenizer.model_max_length``.
    """
    prompt_turns = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {
            "role": "user",
            "content": f"You will be given a sentence. Please paraphrase the sentence.\nSentence: {example['sentence']}",
        },
    ]
    full_turns = prompt_turns + [{"role": "assistant", "content": example["paraphrase"]}]

    encoded = tokenizer.apply_chat_template(
        full_turns, tokenize=True, add_generation_prompt=False, return_dict=True
    )
    # Number of tokens the model sees right before it starts answering.
    n_prompt = len(
        tokenizer.apply_chat_template(
            prompt_turns, tokenize=True, add_generation_prompt=True
        )
    )

    ids = encoded["input_ids"]
    limit = tokenizer.model_max_length
    return {
        "input_ids": ids[:limit],
        "attention_mask": encoded["attention_mask"][:limit],
        "labels": ([-100] * n_prompt + ids[n_prompt:])[:limit],
    }


dsft = load_dataset("json", data_files=str(dsft_path))["train"]

# Drop the raw text columns; keep only input_ids / attention_mask / labels.
columns_to_drop = list(dsft.column_names)
tokenized_dsft = dsft.map(format_and_mask_chat, remove_columns=columns_to_drop)

n_examples = len(tokenized_dsft)
print(f"Tokenized dataset created with {n_examples} examples.")

#### (iii) Train LoRA

In [None]:
from typing import List, Dict


class ChatDataCollator:
    """Collate pre-tokenised chat examples that carry per-token ``labels``.

    ``input_ids``/``attention_mask`` are padded by ``DataCollatorWithPadding``.
    ``labels`` are split off before that and padded here with -100 (the label
    value the loss ignores, as used for the masked prompt tokens). Padding goes
    on the same side the tokenizer pads on, so labels stay aligned with
    ``input_ids``.
    """

    def __init__(self, tokenizer, padding="longest"):
        # Kept so the padding side is read at call time, exactly when the
        # inner collator pads input_ids with the same tokenizer.
        self.tokenizer = tokenizer
        self.pad = DataCollatorWithPadding(tokenizer, padding=padding)

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        # Shallow-copy each feature so popping "labels" does not mutate the
        # caller's dicts.
        features = [dict(feat) for feat in features]
        labels = [feat.pop("labels") for feat in features]

        batch = self.pad(features)

        max_len = batch["input_ids"].size(1)
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        padded = []
        for lab in labels:
            fill = [-100] * (max_len - len(lab))
            padded.append(fill + list(lab) if pad_left else list(lab) + fill)
        batch["labels"] = torch.tensor(padded, dtype=torch.long)
        return batch


model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)

lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg)
model.config.use_cache = False
model.print_trainable_parameters()


training_args = TrainingArguments(
    output_dir=str(WORKDIR / "sft"),
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=1e-4,
    num_train_epochs=3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_strategy="steps",
    logging_steps=100,
    report_to="tensorboard",
    bf16=True,
    optim="adamw_torch",
    save_strategy="epoch",
    save_total_limit=3,
)

collator = ChatDataCollator(tokenizer, padding="longest")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dsft,
    tokenizer=tokenizer,
    data_collator=collator,
)

last_checkpoint = get_last_checkpoint(training_args.output_dir)

print("Starting SFT training...")
trainer.train(resume_from_checkpoint=last_checkpoint)
print("SFT training finished.")

#### (iv) Merge LoRA

In [None]:
print("Merging LoRA adapter and saving the final model...")

from peft import PeftModel

# Locate the newest SFT checkpoint first, so we fail fast before loading any weights.
# get_last_checkpoint only considers directories named `checkpoint-<int>` and picks
# the highest step. Stray entries therefore cannot break the ordering. It lists
# the folder, so guard against a missing output_dir.
sft_output_dir = str(training_args.output_dir)
last_checkpoint_path = (
    get_last_checkpoint(sft_output_dir) if os.path.isdir(sft_output_dir) else None
)
if last_checkpoint_path is None:
    raise ValueError("No SFT checkpoint found to merge.")

# Merge in bf16, the dtype the adapter was trained against (SFT cell) and the dtype
# the merged model is reloaded in for DPO. This avoids round-tripping the weights
# through fp16.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)

sft_model = PeftModel.from_pretrained(base_model, last_checkpoint_path)
# Fold the LoRA deltas into the base weights and drop the PEFT wrapper.
sft_model = sft_model.merge_and_unload()

sft_merged_dir = WORKDIR / "sft_merged"
sft_model.save_pretrained(sft_merged_dir)
tokenizer.save_pretrained(sft_merged_dir)

print(f"Fine-tuned model saved to: {sft_merged_dir}")

## 2️⃣ Direct Preference Optimization (DPO)
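- Build 50 k (prompt, chosen, rejected) preference triples from source sentences **disjoint** from the SFT set. For each source, the paraphrase farthest from it in embedding space (cosine distance) is *chosen*, and the closest one is *rejected*, so DPO rewards diverse paraphrases.
- Train a new LoRA adapter on top of the SFT‑merged model.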

#### (i) Build DDPO_50k

In [None]:
EMB_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Instead of BERT-base CLS
embedder = SentenceTransformer(EMB_MODEL, device=DEVICE)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, token=HF_TOKEN)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def _as_paraphrase_list(pars):
    """Return ``pars`` as a list, parsing it first if it is a stringified list.

    Accepts both representations, like the DSFT cell does. Calling
    ``ast.literal_eval`` on an actual list would raise.
    """
    return ast.literal_eval(pars) if isinstance(pars, str) else list(pars)


def _paraphrase_prompt(src):
    """Render the chat prompt for ``src`` with the same user wording used for SFT and inference."""
    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {
                "role": "user",
                "content": f"You will be given a sentence. Please paraphrase the sentence.\nSentence: {src}",
            },
        ],
        tokenize=False,
        add_generation_prompt=True,
    )


# Only sources never used for SFT, so the preference data is disjoint from DSFT.
raw_ddpo = [ex for ex in raw_all if ex["text"] not in dsft_sources]
raw_ddpo = random.Random(321).sample(raw_ddpo, 50_000)

prefs = []
n_skipped = 0
BATCH = 64
for i in tqdm(range(0, len(raw_ddpo), BATCH)):
    chunk = raw_ddpo[i : i + BATCH]
    sentences = [ex["text"] for ex in chunk]
    paraphrase_lists = [_as_paraphrase_list(ex["paraphrases"]) for ex in chunk]

    # Flatten [src, para_1..para_k, src', ...] so the chunk is embedded in one call.
    flat = []
    for src, plist in zip(sentences, paraphrase_lists):
        flat.append(src)
        flat.extend(plist)
    embs = F.normalize(
        embedder.encode(flat, convert_to_tensor=True, device=DEVICE), p=2, dim=1
    )
    idx = 0
    for src, plist in zip(sentences, paraphrase_lists):
        src_emb = embs[idx]
        par_embs = embs[idx + 1 : idx + 1 + len(plist)]
        idx += 1 + len(plist)
        # A preference pair needs at least two candidates (argmax/argmin of an
        # empty tensor would raise).
        if len(plist) < 2:
            n_skipped += 1
            continue
        dists = 1 - (par_embs @ src_emb)  # cosine distance (rows are unit-normalised)
        # Most distant paraphrase = chosen (diverse); closest = rejected.
        chosen = plist[dists.argmax().item()]
        rejected = plist[dists.argmin().item()]
        # Identical texts give DPO no signal.
        if chosen == rejected:
            n_skipped += 1
            continue
        # Completions are stored as plain text with no end-of-turn marker appended:
        # trl's DPOTrainer appends tokenizer.eos_token_id to chosen/rejected when
        # tokenising (presumably <|eot_id|> for this Llama tokenizer — confirm).
        prefs.append(
            {"prompt": _paraphrase_prompt(src), "chosen": chosen, "rejected": rejected}
        )


ddpo_path = WORKDIR / "DDPO_50k.jsonl"
with ddpo_path.open("w", encoding="utf-8") as f:
    for p in prefs:
        f.write(json.dumps(p, ensure_ascii=False) + "\n")

print("DDPO saved. Size:", len(prefs))
if n_skipped:
    print("Skipped sources without two distinct paraphrases:", n_skipped)

#### (ii) Prepare model & LoRA

In [None]:
ddpo_path = WORKDIR / "DDPO_50k.jsonl"
sft_dir = WORKDIR / "sft_merged"

# Policy model: the SFT-merged checkpoint plus a fresh LoRA adapter trained by DPO.
# `token=` replaces the deprecated `use_auth_token=` and matches every other
# from_pretrained call in this notebook.
model = AutoModelForCausalLM.from_pretrained(
    sft_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)

lora_cfg_dpo = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_cfg_dpo)
# Training-only setting; the KV cache is not used for teacher-forced DPO steps.
model.config.use_cache = False

# Reference model for the DPO loss: a separate copy of the SFT-merged model.
ref_model = AutoModelForCausalLM.from_pretrained(
    sft_dir, torch_dtype=torch.bfloat16, device_map="auto", token=HF_TOKEN
)

#### (iii) Train DPO

In [None]:
# Preference triples with "prompt", "chosen" and "rejected" columns.
ddpo = load_dataset("json", data_files=str(ddpo_path))["train"]

# DPO with beta=0.1. max_prompt_length / max_length cap the prompt and the
# prompt+completion token lengths. Effective batch is 16 x 2 accumulation steps
# per device; one checkpoint per epoch under WORKDIR/dpo (at most 3 kept).
dpo_config = DPOConfig(
    output_dir=str(WORKDIR / "dpo"),
    max_length=256,
    max_prompt_length=128,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=5e-6,
    num_train_epochs=3,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=100,
    bf16=True,
    save_strategy="epoch",
    save_total_limit=3,
    beta=0.1,
    report_to="tensorboard",
)

# Policy = LoRA-wrapped `model`; reference = separate SFT copy `ref_model`.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=dpo_config,
    train_dataset=ddpo,
    tokenizer=tokenizer,
)

# Resume from the newest checkpoint in output_dir if present; None starts fresh.
last_checkpoint_path = get_last_checkpoint(dpo_config.output_dir)
dpo_trainer.train(resume_from_checkpoint=last_checkpoint_path)

#### (iv) Merge and save

In [None]:
# Fold the DPO LoRA deltas into the SFT-merged weights.
final_model = model.merge_and_unload()
# `use_cache` was switched off on this model only for DPO training. Restore the
# default (True) so the exported config does not carry the training-only setting
# into inference.
final_model.config.use_cache = True
final_dir = WORKDIR / "doaug_paraphraser"

final_model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)

print("DPO‑finished model saved to", final_dir)

3️⃣ Quick Inference Check
---

In [21]:
from transformers import pipeline

# Text-generation pipeline over the merged DPO checkpoint, reusing the in-memory tokenizer.
paraphraser = pipeline(
    task="text-generation",
    model=str(final_dir),
    tokenizer=tokenizer,
    device=DEVICE,
)


def paraphrase(sentence: str) -> str:
    """Return one sampled paraphrase of ``sentence`` from the DPO-tuned model.

    Builds the same system/user chat prompt used during SFT, samples up to 64 new
    tokens (temperature 0.7, top-p 0.9, stopping at the tokenizer's EOS), and
    returns only the newly generated text with surrounding whitespace stripped.
    """
    prompt = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {
                "role": "user",
                "content": f"You will be given a sentence. Please paraphrase the sentence.\nSentence: {sentence}",
            },
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    # The rendered template already starts with the BOS token. A plain-string
    # prompt is tokenised by the pipeline with the tokenizer's default special-token
    # handling. If that also prepends BOS, drop the textual copy so BOS is not doubled.
    bos = tokenizer.bos_token
    adds_bos = bos is not None and tokenizer("")["input_ids"][:1] == [tokenizer.bos_token_id]
    if adds_bos and prompt.startswith(bos):
        prompt = prompt[len(bos):]
    generated = paraphraser(
        prompt,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id,
        # Return only the continuation, so no template-specific string splitting
        # is needed to strip the prompt.
        return_full_text=False,
    )[0]["generated_text"]
    return generated.strip()


# Smoke test: paraphrase one fixed sentence with the exported model.
sample_sentence = "A single candle lit the dark, quiet room."
print("Test:")
print(paraphrase(sample_sentence))

Test:
Ambiance was restored to the space with the flicker of a flame.
