In [2]:
import os

# -----------------------------
# 0) Environment variables (Windows / Hugging Face)
# -----------------------------
# These are set BEFORE the Hugging Face packages below are imported:
# huggingface_hub reads its cache location and feature switches from the
# environment when it is first imported, so assigning them after
# `import transformers` / `import datasets` may have no effect.
os.environ["HF_HUB_DISABLE_XET"] = "1"
os.environ["HF_HUB_DISABLE_SYMLINKS"] = "1"
os.environ["HF_HUB_ENABLE_TQDM_MULTIPROCESSING"] = "0"
# Respect a user-provided HF_HOME; otherwise fall back to a fixed local cache directory.
os.environ["HF_HOME"] = os.environ.get("HF_HOME", r"C:\hf_cache_clean")

import inspect
import random
import re
from typing import Dict, Any, List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    default_data_collator,
    set_seed,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Single seed shared by transformers.set_seed, the dataframe shuffle and the train/eval split.
SEED = 42
set_seed(SEED)

# -----------------------------
# 1) Configuration
# -----------------------------
# Input CSV; the loading code below requires 'content', 'label' and 'explanation' columns.
DATA_CSV = "라벨완성본+일상대화_확장.csv"
# Passed to TrainingArguments as output_dir (checkpoints written during training).
OUTPUT_DIR = "./korsmishing-qlora"
# Where the model and tokenizer are saved once training has finished.
SAVE_DIR   = "./korsmishing-qlora-smishing-expl_확장"

# Base checkpoint; can be overridden through the BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("BASE_MODEL", "LGAI-EXAONE/EXAONE-4.0-1.2B")
# Keyword arguments shared by the tokenizer and model from_pretrained calls.
common_load = dict(trust_remote_code=True)

# Training hyperparameters
N_EPOCHS = 5
LR = 2e-5
PER_TRAIN_BS = 4
PER_EVAL_BS  = 4
GRAD_ACCUM   = 1

# QLoRA
LORA_R = 8
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
# Module names tried first when choosing LoRA targets; if none is present in the
# model, pick_target_modules falls back to q_proj/k_proj/v_proj/o_proj.
TARGET_MODULES_PRIORITY = ["query_key_value"]

# Lengths / generation
MAX_LEN = 512              # fixed token length of every training example (prompt + target + padding)
MAX_TARGET_TOKENS = 256    # cap on target tokens before the prompt budget is computed
GEN_MAX_NEW_TOKENS = 128   # not referenced in the visible code; presumably used by generation code elsewhere

# Recommended EXAONE settings
# (not referenced in the visible code; presumably used at generation time — verify)
GEN_TEMPERATURE = 0.1
GEN_REP_PENALTY = 1.1

# Class-imbalance weighted loss
POS_NAME, NEG_NAME = "스미싱", "정상"
# Share of the total example weight assigned to the positive class.
TARGET_POS_RATIO = 0.5
# Constant multiplier applied to every training example weight.
WEIGHT_SCALE = 50.0

# -----------------------------
# 2) TrainingArguments 호환 래퍼
# -----------------------------
def make_training_args(**kw):
    """Build a TrainingArguments from keyword arguments the installed version accepts.

    The evaluation-strategy option is known under two names
    ("evaluation_strategy" and "eval_strategy"). Either spelling may be passed;
    it is forwarded under whichever name TrainingArguments.__init__ declares
    ("eval_strategy" is preferred when both exist).

    Args:
        **kw: Candidate TrainingArguments keyword arguments.

    Returns:
        A TrainingArguments instance built from the accepted subset of ``kw``.
        Arguments the installed version does not declare are skipped, and a
        "[WARN]" line listing them is printed so the omission is visible.
    """
    params = set(inspect.signature(TrainingArguments.__init__).parameters)
    eval_aliases = ("eval_strategy", "evaluation_strategy")
    out = {}

    # Resolve the evaluation-strategy alias in both directions.
    eval_key = next((name for name in eval_aliases if name in params), None)
    if eval_key is not None:
        # The spelling the installed version declares wins if both were passed.
        lookup_order = (eval_key,) + tuple(a for a in eval_aliases if a != eval_key)
        for name in lookup_order:
            if name in kw:
                out[eval_key] = kw[name]
                break

    dropped = []
    for key, value in kw.items():
        if key in eval_aliases:
            # Already handled above (or unsupported under either name).
            if eval_key is None:
                dropped.append(key)
            continue
        if key in params:
            out[key] = value
        else:
            dropped.append(key)

    if dropped:
        print(
            "[WARN] TrainingArguments does not accept these arguments; ignoring:",
            ", ".join(sorted(dropped)),
        )
    return TrainingArguments(**out)

def supports_bf16():
    """Return True when CUDA is available and device 0 has compute capability major >= 8.

    Returns False when no CUDA device is available.
    """
    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability(0)
        major_version, _minor_version = capability
        return major_version >= 8
    return False

# -----------------------------
# 3) Data loading / normalization
# -----------------------------
df = pd.read_csv(DATA_CSV)

# Fail early if a required column is absent.
for col in ["content", "label", "explanation"]:
    if col not in df.columns:
        raise ValueError(f"CSV에는 '{col}' 컬럼이 있어야 합니다.")

# Map numeric / English labels onto the two Korean class names; values that are
# not in the map are kept as-is and then filtered by the isin() check below.
label_map = {0: "정상", 1: "스미싱", "normal": "정상", "smishing": "스미싱"}
df["label"] = df["label"].map(label_map).fillna(df["label"])
df = df[df["label"].isin([POS_NAME, NEG_NAME])].copy()

# Rows with a missing message or explanation would be formatted as the literal
# text "nan" in the prompt / target, so drop them and report how many were removed.
_n_before_dropna = len(df)
df = df.dropna(subset=["content", "explanation"]).copy()
if len(df) < _n_before_dropna:
    print(f"[WARN] dropped {_n_before_dropna - len(df)} rows with missing content/explanation")

# -----------------------------
# 4) 프롬프트/타깃 구성 (수정됨: 간결한 지시 포함)
# -----------------------------
def make_prompt(row):
    """Render the chat-formatted prompt for one row.

    Args:
        row: Mapping-like object; only ``row['content']`` (the message text) is read.

    Returns:
        A string with a system turn, a user turn containing the instruction and
        the message, and an opening ``[|assistant|]`` tag for the model to continue.
    """
    sms_text = row['content']
    system_msg = "너는 문자를 분석하여 스미싱 여부를 판단하고, 그 이유를 설명하는 보안 전문가야."

    # NOTE(review): the message text is followed directly by "<답변>" with no
    # newline in between — confirm this matches the prompt used at inference time.
    user_pieces = [
        "다음 $$문자$$를 보고 먼저 $$스미싱 여부$$를 판단한 뒤, 그 이유를 한두 문장으로 제시하세요.\n",
        "형식:\n",
        "$$스미싱 여부$$: (스미싱/정상)\n",
        "$$설명$$: (간단한 설명)\n\n",
        f"$$문자$$\n{sms_text}",
        "<답변>\n",
    ]
    user_msg = "".join(user_pieces)

    turns = (
        f"[|system|]{system_msg}[|endofturn|]\n",
        f"[|user|]{user_msg}[|endofturn|]\n",
        "[|assistant|]",
    )
    return "".join(turns)

def make_target(row):
    """Return the training target for one row: its explanation followed by the end-of-turn tag.

    Args:
        row: Mapping-like object; only ``row['explanation']`` is read.
    """
    explanation = row['explanation']
    return "{}[|endofturn|]".format(explanation)

# Shuffle all rows, then hold out 20% for evaluation, stratified on the label.
df = df.sample(frac=1, random_state=SEED).reset_index(drop=True)
df_train, df_eval = train_test_split(df, test_size=0.2, stratify=df["label"], random_state=SEED)

# Add the rendered prompt and target text to both splits.
for d in (df_train, df_eval):
    d["prompt"] = d.apply(make_prompt, axis=1)
    d["target"] = d.apply(make_target, axis=1)

# Weight computation
# Each class's share of the total weight (TARGET_POS_RATIO for the positive
# class, the remainder for the negative class) is spread evenly over that
# class's training rows; max(1, ...) guards against division by zero.
N_pos = int((df_train["label"] == POS_NAME).sum())
N_neg = int((df_train["label"] == NEG_NAME).sum())
w_pos = TARGET_POS_RATIO / max(1, N_pos)
w_neg = (1 - TARGET_POS_RATIO) / max(1, N_neg)

# NOTE(review): when both classes are present the training weights sum to
# WEIGHT_SCALE over the whole split, so with thousands of rows each weight is
# far below 1, while evaluation rows use 1.0. Training and evaluation loss are
# therefore on different scales — confirm this is intended.
df_train["example_weight"] = df_train["label"].map({POS_NAME: w_pos, NEG_NAME: w_neg}).astype("float32") * WEIGHT_SCALE
df_eval["example_weight"]  = 1.0

# Keep only the columns used downstream and wrap both splits in a DatasetDict.
train_ds = Dataset.from_pandas(df_train[["prompt","target","label","explanation","example_weight"]].reset_index(drop=True))
eval_ds  = Dataset.from_pandas(df_eval[ ["prompt","target","label","explanation","example_weight"]].reset_index(drop=True))
raw_dataset = DatasetDict({"train": train_ds, "eval": eval_ds})

# -----------------------------
# 5) Tokenizer / model / QLoRA
# -----------------------------
# 4-bit loading is only attempted when CUDA is available. Keeping USE_4BIT in
# sync with that means every later consumer of the flag (model loading,
# optimizer choice) sees the same answer.
USE_4BIT = torch.cuda.is_available()
bnb_config = None
if USE_4BIT:
    try:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            # Same capability probe as the rest of the script: bf16 when supported, else fp16.
            bnb_4bit_compute_dtype=torch.bfloat16 if supports_bf16() else torch.float16,
            bnb_4bit_use_double_quant=True,
        )
    except Exception as e:
        # Best effort: fall back to an unquantized load instead of aborting.
        print("[WARN] bitsandbytes unavailable:", e)
        USE_4BIT = False
        bnb_config = None
else:
    print("[WARN] CUDA not available; 4-bit quantization disabled.")

def load_tokenizer():
    """Load the fast tokenizer for BASE_MODEL, configured for right-side padding.

    If the tokenizer defines no pad token, the EOS token is reused as the pad token.

    Returns:
        The configured tokenizer.
    """
    loaded = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, **common_load)
    needs_pad_token = loaded.pad_token is None
    if needs_pad_token:
        loaded.pad_token = loaded.eos_token
    loaded.padding_side = "right"
    return loaded

# Module-level tokenizer shared by build_item and the final save step.
tokenizer = load_tokenizer()
# Token id used by build_item to pad input_ids up to MAX_LEN.
PAD_ID = tokenizer.pad_token_id

def load_base_model():
    """Load BASE_MODEL and prepare it for k-bit (QLoRA) training.

    With CUDA available and a usable ``bnb_config`` the model is loaded with
    that quantization config. Otherwise it is loaded unquantized in bfloat16
    (when ``supports_bf16()`` is true), float16 (other CUDA devices) or
    float32 (no CUDA).

    Returns:
        The model after ``prepare_model_for_kbit_training`` with gradient
        checkpointing enabled.
    """
    kwargs = dict(**common_load)
    cuda_ok = torch.cuda.is_available()
    if cuda_ok:
        kwargs["device_map"] = "auto"

    if USE_4BIT and bnb_config is not None and cuda_ok:
        kwargs["quantization_config"] = bnb_config
    elif supports_bf16():
        # Reuse the shared capability probe instead of repeating it inline.
        kwargs["torch_dtype"] = torch.bfloat16
    else:
        kwargs["torch_dtype"] = torch.float16 if cuda_ok else torch.float32

    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **kwargs)
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    return model

# Base model instance; it is wrapped with LoRA adapters further below.
model = load_base_model()

def pick_target_modules(m: nn.Module, priority: "List[str] | None" = None) -> List[str]:
    """Choose which submodule names LoRA should target in ``m``.

    Args:
        m: Model whose ``named_modules()`` are inspected.
        priority: Candidate names, tried in order. Defaults to the
            module-level ``TARGET_MODULES_PRIORITY`` when omitted.

    Returns:
        A one-element list holding the first candidate present in the model,
        or ``["q_proj", "k_proj", "v_proj", "o_proj"]`` when none is present.
    """
    candidates = TARGET_MODULES_PRIORITY if priority is None else priority
    names = [n for n, _ in m.named_modules()]
    for pref in candidates:
        # Match a whole dotted component, not an arbitrary suffix, so a
        # candidate such as "q_proj" is not satisfied by "my_q_proj".
        # NOTE(review): this mirrors how PEFT is understood to match a list of
        # target names (exact name or ".<name>" suffix) — confirm against the
        # installed PEFT version.
        dotted = "." + pref
        if any(n == pref or n.endswith(dotted) for n in names):
            return [pref]
    return ["q_proj", "k_proj", "v_proj", "o_proj"]

# Decide which submodules receive LoRA adapters and log the choice.
target_modules = pick_target_modules(model)
print("[LoRA] target_modules =", target_modules)

# LoRA configuration for causal-LM fine-tuning; bias parameters are not trained.
peft_cfg = LoraConfig(
    r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
    bias="none", task_type="CAUSAL_LM", target_modules=target_modules
)
# Wrap the base model with the adapters and print the trainable-parameter summary.
model = get_peft_model(model, peft_cfg)
model.print_trainable_parameters()

# -----------------------------
# 6) Tokenization and padding (verification output removed)
# -----------------------------
# Maximum number of target tokens kept per example (alias of MAX_TARGET_TOKENS).
MAX_TARGET = MAX_TARGET_TOKENS
# Label value skipped by the cross-entropy loss (passed as ignore_index in the trainer).
IGNORE_INDEX = -100

def _trim_to_max(ids, n):
    return ids if len(ids) <= n else ids[:n]

def build_item(ex: Dict[str, Any]) -> Dict[str, Any]:
    """Tokenize one example into fixed-length causal-LM training features.

    The target is capped at MAX_TARGET tokens; the prompt then receives the
    remaining budget and is truncated from the left (its last tokens are
    kept). Prompt and padding positions are labelled IGNORE_INDEX, so only
    target tokens contribute to the loss.

    Args:
        ex: Example with "prompt", "target" and "example_weight" entries.

    Returns:
        Dict with "input_ids", "labels" and "attention_mask" (each of length
        MAX_LEN) plus "example_weight" as a float.
    """
    encode_opts = dict(padding=False, truncation=False)
    prompt_tokens = tokenizer(ex["prompt"], add_special_tokens=True, **encode_opts)["input_ids"]
    answer_tokens = tokenizer(ex["target"], add_special_tokens=False, **encode_opts)["input_ids"]
    answer_tokens = _trim_to_max(answer_tokens, MAX_TARGET)

    prompt_budget = MAX_LEN - len(answer_tokens)
    if prompt_budget <= 0:
        # The target alone fills the sequence: shrink it so the prompt keeps some room.
        answer_tokens = answer_tokens[:max(8, MAX_LEN // 2)]
        prompt_budget = MAX_LEN - len(answer_tokens)

    kept_prompt = prompt_tokens[-prompt_budget:]
    sequence = (kept_prompt + answer_tokens)[:MAX_LEN]
    supervised = ([IGNORE_INDEX] * len(kept_prompt) + answer_tokens)[:MAX_LEN]

    real_len = len(sequence)
    fill = MAX_LEN - real_len  # number of right-padding positions (0 when already full)

    return {
        "input_ids": sequence + [PAD_ID] * fill,
        "labels": supervised + [IGNORE_INDEX] * fill,
        "attention_mask": [1] * real_len + [0] * fill,
        "example_weight": float(ex["example_weight"]),
    }

# Tokenize every split with build_item; the raw columns listed for the train
# split are removed in favour of build_item's outputs.
proc_dataset = raw_dataset.map(build_item, remove_columns=raw_dataset["train"].column_names)

# -----------------------------
# 6.5) Label boost (disabled)
# -----------------------------
# Loss multiplier applied to the first LABEL_HEAD_TOKENS target tokens of each example.
LABEL_BOOST = 1.0        # 1.0 = boost disabled
# Number of leading target tokens the boost applies to when enabled.
LABEL_HEAD_TOKENS = 10   

# -----------------------------
# 7) 가중 손실 Trainer
# -----------------------------
class WeightedLossTrainer(Trainer):
    """Trainer whose loss is a per-example weighted, length-normalized token cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the weighted next-token loss for one batch.

        For each example the token losses over non-ignored label positions are
        averaged (optionally up-weighting the first LABEL_HEAD_TOKENS target
        tokens by LABEL_BOOST), multiplied by the example's weight, and the
        results are averaged over the batch.

        Args:
            model: Model to run; called without labels so only logits are used.
            inputs: Batch dict. "labels" and the optional "example_weight" are
                popped from it before the forward pass.
            return_outputs: When true, also return the model outputs.
            **kwargs: Accepted and ignored, so callers may pass extra arguments.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        weights = inputs.pop("example_weight", None)
        outputs = model(**inputs, labels=None)
        logits = outputs.logits

        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            reduction="none",
            ignore_index=IGNORE_INDEX,
        )
        vocab = shift_logits.size(-1)
        # Per-token loss reshaped back to one row per example.
        token_loss = loss_fct(
            shift_logits.view(-1, vocab),
            shift_labels.view(-1)
        ).view(shift_labels.size(0), -1)

        # 1.0 on supervised (target) positions, 0.0 on prompt / padding positions.
        base_mask = (shift_labels != IGNORE_INDEX).float()

        if LABEL_BOOST != 1.0:
            # Vectorized head mask: positions [start, start + LABEL_HEAD_TOKENS)
            # of each row, where start is the first supervised position. This
            # avoids a per-row Python loop and its host synchronisation.
            seq_len = shift_labels.size(1)
            target_start = base_mask.argmax(dim=1, keepdim=True)
            positions = torch.arange(seq_len, device=shift_labels.device).unsqueeze(0)
            in_head = (positions >= target_start) & (positions < target_start + LABEL_HEAD_TOKENS)
            boost_mask = in_head.to(base_mask.dtype)
            eff_mask = base_mask * (1.0 + (LABEL_BOOST - 1.0) * boost_mask)
        else:
            eff_mask = base_mask

        # Weighted mean over each row; clamp so rows without targets do not divide by zero.
        ex_loss = (token_loss * eff_mask).sum(dim=1) / eff_mask.sum(dim=1).clamp_min(1.0)

        if weights is not None:
            if not torch.is_tensor(weights):
                weights = torch.tensor(weights, dtype=ex_loss.dtype, device=ex_loss.device)
            else:
                weights = weights.to(ex_loss.device, dtype=ex_loss.dtype)
            ex_loss = ex_loss * weights

        loss = ex_loss.mean()
        return (loss, outputs) if return_outputs else loss

# -----------------------------
# 8) TrainingArguments
# -----------------------------
_cuda_ok = torch.cuda.is_available()
use_bf16 = supports_bf16()
# fp16 mixed precision and the paged 8-bit optimizer both need a CUDA device,
# so neither is requested when the script falls back to CPU.
use_fp16 = _cuda_ok and not use_bf16
optim_name = "paged_adamw_8bit" if (USE_4BIT and _cuda_ok) else "adamw_torch"

args = make_training_args(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_TRAIN_BS,
    per_device_eval_batch_size=PER_EVAL_BS,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=N_EPOCHS,
    learning_rate=LR,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.01,
    max_grad_norm=1.0,
    logging_steps=50,
    # Evaluate and save on the same step schedule so the best checkpoint can be reloaded.
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    bf16=use_bf16,
    fp16=use_fp16,
    optim=optim_name,
    gradient_checkpointing=True,
    report_to=[],
    dataloader_num_workers=0,
    # Keep "example_weight" in the batch; WeightedLossTrainer.compute_loss reads it.
    remove_unused_columns=False,
)

# Build the weighted-loss trainer. Examples are already padded to MAX_LEN by
# build_item, so default_data_collator is used for batching.
trainer = WeightedLossTrainer(
    model=model,
    args=args,
    train_dataset=proc_dataset["train"],
    eval_dataset=proc_dataset["eval"],
    data_collator=default_data_collator,
)

# Start training
print("=== Training Started ===")
trainer.train()

# Save the model and tokenizer once training has finished
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
print(f"=== Training Finished & Saved to {SAVE_DIR} ===")

[LoRA] target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
trainable params: 3,194,880 || all params: 1,282,586,368 || trainable%: 0.2491


Map:   0%|          | 0/14616 [00:00<?, ? examples/s]

Map:   0%|          | 0/3654 [00:00<?, ? examples/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


=== Training Started ===


Step,Training Loss,Validation Loss
200,0.005,1.165405
400,0.0029,0.634275
600,0.002,0.473565
800,0.0018,0.411365
1000,0.0013,0.373438
1200,0.0015,0.351493
1400,0.0014,0.328754
1600,0.0012,0.316918
1800,0.0012,0.303702
2000,0.0011,0.294002


=== Training Finished & Saved to ./korsmishing-qlora-smishing-expl_확장 ===
