#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@name: fine_tune_llm.py
@author: Finbarrs Oketunji
@contact: f@finbarrs.eu
@time: Sunday January 14 23:22:00 2024
@desc: LLM Fine-tuning.
@run: python3 fine_tune_llm.py
"""
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TextDataset,
    Trainer,
    TrainingArguments,
)

# Load the pre-trained model and tokenizer
model_name = "gpt2-medium"  # Replace with the desired pre-trained model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load and preprocess the dataset
train_file = "dataset/train.txt"  # Replace with your training dataset file
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path=train_file,
    block_size=128,  # Adjust the block size as needed
)
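
# Note: TextDataset is deprecated in recent transformers releases. A minimal
# sketch of the replacement path, assuming the separate `datasets` library is
# installed (commented out so it does not change this script's behaviour):
#
#   from datasets import load_dataset
#   tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token
#   raw = load_dataset("text", data_files={"train": train_file})
#   train_dataset = raw["train"].map(
#       lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
#       batched=True,
#       remove_columns=["text"],
#   )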

# Define the data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # Causal LM objective; set to True only for masked-LM models (e.g. BERT), not GPT-2
)

# Define the training arguments
training_args = TrainingArguments(
    output_dir="output",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
)
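
# With save_steps=10_000 and save_total_limit=2, a checkpoint is written every
# 10,000 optimizer steps and only the two most recent checkpoints are kept.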

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("gpt2_medium_fine_tuned_model")
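
# Save the tokenizer alongside the model so the output directory is
# self-contained (an addition to the original script; standard practice).
tokenizer.save_pretrained("gpt2_medium_fine_tuned_model")

# --- Hedged usage sketch (illustrative, not part of the original) ---
# Reload the saved model and generate a short completion as a sanity check.
# The prompt and sampling settings below are assumptions.
fine_tuned_model = AutoModelForCausalLM.from_pretrained("gpt2_medium_fine_tuned_model")
inputs = tokenizer("Once upon a time", return_tensors="pt")
outputs = fine_tuned_model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token by default
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))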