In [None]:
# Install TRL into the running kernel's environment (%pip, unlike !pip, targets the kernel).
# NOTE(review): the version is unpinned, while this notebook relies on
# DataCollatorForCompletionOnlyLM and SFTConfig(max_seq_length=...). Pin a TRL
# release known to provide both so a fresh run stays reproducible — TODO confirm which.
%pip install trl

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

In [None]:
# Single source of truth for the checkpoint id, so the model and its tokenizer
# are always loaded from the same place and cannot drift apart.
MODEL_ID = "facebook/opt-350m"

# Training split of the CodeAlpaca-20k dataset; formatting_prompts_func reads
# its `instruction` and `output` columns.
dataset = load_dataset("sahil2801/CodeAlpaca-20k", split="train")

# Causal LM and matching tokenizer, both taken from MODEL_ID.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
def formatting_prompts_func(example):
    """Render a batch of examples as question/answer training strings.

    Args:
        example: Mapping holding parallel, position-indexed sequences under
            the keys ``'instruction'`` and ``'output'``.

    Returns:
        A list with one string per instruction, each shaped like
        ``"### Question: <instruction>\\n ### Answer: <output>"``.
    """
    # The output is looked up by position so that a missing or too-short
    # 'output' sequence fails loudly instead of being silently truncated.
    return [
        f"### Question: {question}\n ### Answer: {example['output'][idx]}"
        for idx, question in enumerate(example['instruction'])
    ]

In [None]:
# Marker separating the prompt from the completion. It appears verbatim in the
# strings built by formatting_prompts_func (note the leading space, matching the
# "\n ### Answer:" produced there).
response_template = " ### Answer:"

# Completion-only collator keyed on that marker.
# NOTE(review): presumably this limits the loss to the tokens that follow the
# marker — confirm against the TRL documentation.
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

In [None]:
# Checkpoint settings
output_dir = "/tmp/clm-instruction-tuning"
# Alternative: "steps" (to save every 500 steps instead of once per epoch).
save_strategy = "epoch"

sft_args = SFTConfig(
    # --- SFT-specific parameters ---
    packing=False,
    max_seq_length=512,
    # --- Checkpointing ---
    output_dir=output_dir,
    save_strategy=save_strategy,  # save when an epoch finishes
    save_total_limit=2,  # keep only the two most recent checkpoints
    # NOTE(review): the resume decision is actually taken by the argument passed
    # to trainer.train(...); per the transformers docs this config field is not
    # read by Trainer itself — confirm whether it is still needed here.
    resume_from_checkpoint=True,
)

In [None]:
import os

trainer = SFTTrainer(
    model,
    args=sft_args,  # keep using the SFTConfig instance
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

# Resume only when the output directory already contains something that looks
# like a checkpoint. isdir() (rather than exists()) keeps listdir() from failing
# when the path exists but is not a directory.
checkpoint_exists = os.path.isdir(output_dir) and any(
    "checkpoint" in entry for entry in os.listdir(output_dir)
)

# Conditionally resume training.
try:
    trainer.train(resume_from_checkpoint=checkpoint_exists)
except ValueError as e:
    # The name-based check above can yield a false positive (an entry whose name
    # merely contains "checkpoint"). Only that specific failure triggers the
    # fresh-start fallback; any other ValueError is a genuine error and is
    # re-raised instead of being silently swallowed.
    if "No valid checkpoint" not in str(e):
        raise
    print("체크포인트 없음. 새 학습 시작")
    trainer.train(resume_from_checkpoint=False)