In [1]:
import json
from sklearn.model_selection import train_test_split
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, DatasetDict
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import EarlyStoppingCallback
import evaluate
import wandb

# 1. Load and prepare the data
# Read the whole corpus from corpus.json.
with open("corpus.json", mode="r", encoding="utf-8") as corpus_file:
    corpus_data = json.load(corpus_file)

# 2. Train/validation split (20% held out; fixed seed keeps the split stable)
train_data, val_data = train_test_split(
    corpus_data,
    test_size=0.2,
    random_state=42,
)

# 3. Build the dataset splits
dataset = DatasetDict(
    {
        split_name: Dataset.from_list(rows)
        for split_name, rows in (("train", train_data), ("validation", val_data))
    }
)

# 4. Model and tokenizer setup
# Alternative checkpoint (gated repo, needs an authenticated Hugging Face
# login): "meta-llama/Llama-3.2-3B-Instruct"
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# 5. Prompt formatting function
def formatting_prompts_func(example):
    """Render one example as an instruction/response prompt string.

    Reads ``example['input']`` and ``example['output']`` and returns a single
    two-line string: an ``### Instruction:`` line followed by a
    ``### Response:`` line.
    """
    instruction = example["input"]
    response = example["output"]
    return "### Instruction: {}\n### Response: {}".format(instruction, response)

# 6. Data collator setup
# The collator is keyed on the same "### Response:" marker that
# formatting_prompts_func writes into every prompt.
# NOTE(review): confirm the marker tokenizes identically inside a full prompt
# for this tokenizer, otherwise the collator may not locate it.
response_template = "### Response:"
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

# 7. Weights & Biases run setup
wandb.init(
    project="gemma-instruction-tuning",
    name="google/gemma-2-2b-it",
)

# 8. Load the ROUGE and BLEU metrics (in that order)
rouge_metric, bleu_metric = (
    evaluate.load(metric_id) for metric_id in ("rouge", "bleu")
)

# 9. Evaluation metric function
def compute_metrics(eval_preds):
    """Compute BLEU and ROUGE for one evaluation pass and log them to W&B.

    Args:
        eval_preds: A ``(predictions, labels)`` pair as handed over by the
            trainer. ``predictions`` may be token ids or raw logits with one
            extra trailing (vocabulary) axis; ``labels`` are token ids in
            which ignored positions are marked with ``-100``.
            Assumes both are 2-D ``(batch, seq)`` after the argmax step --
            TODO confirm against the trainer/collator in use.

    Returns:
        dict with the float entries ``bleu_score``, ``rouge1``, ``rouge2``
        and ``rougeL``.
    """
    # Local import: numpy is not imported at file level.
    import numpy as np

    predictions, labels = eval_preds
    if isinstance(predictions, tuple):
        # Some models return extra outputs next to the logits; keep the first.
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Raw logits carry one more axis than the labels: reduce them to ids.
    if predictions.ndim == labels.ndim + 1:
        predictions = predictions.argmax(axis=-1)

    # A causal LM predicts token t+1 at position t, so drop the last
    # prediction and the first label to line both sequences up.
    predictions = predictions[:, :-1]
    labels = labels[:, 1:]

    # -100 marks ignored/padded positions and is not a valid token id;
    # decoding it fails. Replace those positions (in both arrays, so the
    # prompt part is not scored) with a token that decoding will skip.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    scored = labels != -100
    predictions = np.where(scored, predictions, pad_id)
    labels = np.where(scored, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # BLEU: the `evaluate` metric takes raw strings (it tokenizes itself),
    # with one list of reference strings per prediction.
    bleu = bleu_metric.compute(
        predictions=decoded_preds,
        references=[[reference] for reference in decoded_labels],
    )

    # ROUGE: the `evaluate` metric already returns aggregated F-measures as
    # plain floats, so there is no `.mid.fmeasure` to unwrap.
    rouge = rouge_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True,
    )

    metrics = {
        "bleu_score": float(bleu["bleu"]),
        "rouge1": float(rouge["rouge1"]),
        "rouge2": float(rouge["rouge2"]),
        "rougeL": float(rouge["rougeL"]),
    }

    # Log a copy to W&B so later changes to the returned dict by the caller
    # cannot affect what was logged.
    wandb.log(dict(metrics))

    return metrics

# Alternative output directory for a Llama run: "./finetuned_llama"
output_dir = "./finetuned_gemma"


def _argmax_logits(logits, labels):
    """Reduce logits to token ids before the trainer accumulates them.

    Without this, full-vocabulary logits for the entire validation set are
    kept in memory until compute_metrics runs.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


# 10. SFT Trainer setup
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    args=SFTConfig(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=10,
        logging_dir="./logs",
        logging_steps=100,
        # Early stopping needs periodic evaluation, a metric to track and
        # best-model tracking; evaluation and saving share the same cadence
        # so the best checkpoint can be restored at the end.
        eval_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        report_to=["wandb"],  # send logs to W&B
    ),
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    compute_metrics=compute_metrics,  # BLEU/ROUGE at every evaluation
    preprocess_logits_for_metrics=_argmax_logits,
)

# Stop once eval_loss has failed to improve for 2 consecutive evaluations.
trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))

# 11. Start training, resuming from the newest checkpoint when one exists
try:
    last_checkpoint = get_last_checkpoint(output_dir)
except FileNotFoundError:
    # The output directory has not been created yet: nothing to resume from.
    last_checkpoint = None

# resume_from_checkpoint=None means "start from scratch".
trainer.train(resume_from_checkpoint=last_checkpoint)

wandb.finish()

OSError: You are trying to access a gated repo.
Make sure to have access to it at https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct.
401 Client Error. (Request ID: Root=1-67a6175f-29c63bd633f47257340e91e1;43e450ae-ac9e-44f4-9059-d6e64d9d6045)

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-3B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in.