Based on: [Sculpting Language: GPT-2 Fine-Tuning with LoRA](https://blog.devgenius.io/sculpting-language-gpt-2-fine-tuning-with-lora-1caf3bfbc3c6)

In [1]:
import math
import json
import torch
import transformers
from typing import Literal
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Converting Dataset Into Training Dataset 

In [2]:
# Raw source file; not referenced by the cells below.
data_location = "dataset/data.json"

# JSON file for each split, consumed by `load_dataset` further down.
data_for_training: dict[Literal["train", "validation"], str] = dict(
    train="data/train_data.json",
    validation="data/val_data.json",
)

# Training

## Configuration Before Training

In [3]:
# Local cache for downloaded Hugging Face files, and the checkpoint to fine-tune.
cache_dir = "models"
modelID = "openai-community/gpt2"

# Preferred torch device. Model placement below is done via device_map='auto',
# so this value is not used for placement in this notebook.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Tokenizer: reuse the EOS token for padding and pad on the right of each text.
tokenizer = AutoTokenizer.from_pretrained(modelID, cache_dir=cache_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Base causal-LM weights; device_map='auto' lets transformers decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    modelID,
    cache_dir=cache_dir,
    device_map='auto',
)

# Keep the model config's pad id in sync with the tokenizer's (EOS-based) pad token.
model.config.pad_token_id = tokenizer.pad_token_id

In [5]:
# FREEZE WEIGHTS: turn off gradients for every base-model parameter up front.
# (get_peft_model in the next cell is expected to leave only the LoRA adapters trainable.)
# Trailing ';' keeps the returned module from being echoed as cell output.
model.requires_grad_(False);

In [None]:
# LoRA adapter hyper-parameters for causal language modelling:
# rank 16, alpha 32 (PEFT scales the adapter update by lora_alpha / r = 2),
# 5% dropout on the adapter input, bias terms left untouched.
# target_modules is left to PEFT's default for this architecture.
# An earlier experiment used r=32, lora_alpha=16, lora_dropout=0.1.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the base model so that training updates the LoRA adapter weights.
model = get_peft_model(model, lora_config)

In [7]:
# Build a DatasetDict with 'train' and 'validation' splits from the JSON files and
# shuffle each split with a fixed seed. (A 'test' split at data/test_data.json was
# previously loaded here as well.)
splits = ('train', 'validation')
shuffled_dataset = load_dataset(
    'json',
    data_files={split: data_for_training[split] for split in splits},
).shuffle(seed=42, keep_in_memory=True)

In [None]:
# Load the raw training records from the same file `load_dataset` reads, so the two
# cannot drift apart. Assumed format: a JSON array of objects with a 'text' field.
# JSON is UTF-8 by spec; do not rely on the platform's default encoding.
with open(data_for_training["train"], 'r', encoding='utf-8') as file:
    train_data = json.load(file)

# Token count of every training text (no padding/truncation); used to pick max_length.
lengths = [len(tokenizer(obj['text'])['input_ids']) for obj in train_data]

def next_power_of_2(n):
    """Return the smallest power of two that is >= ``n``.

    Args:
        n: A number >= 1 (normally a positive ``int`` such as a token count).

    Returns:
        int: ``2 ** k`` for the smallest integer ``k >= 0`` with ``2 ** k >= n``.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError("Input must be a positive integer.")

    # Use exact integer arithmetic. math.log2 converts to float, which rounds large
    # ints (e.g. 2**60 + 1 -> 2**60) and can yield a power of two smaller than n.
    # Powers of two >= 1 are integers, so rounding a fractional n up first is lossless.
    n = math.ceil(n)
    return 1 << (n - 1).bit_length()

# Report the longest tokenized training text, then round it up to a power of two.
longest_text = max(lengths)
print("Maximum length:", longest_text)
max_length = next_power_of_2(longest_text)
print("max_length:", max_length)

In [10]:
# Deliberately overrides the max_length computed in the previous cell: every sample is
# padded/truncated to 64 tokens, so texts longer than 64 tokens lose their tail
# (an alternative value of 128 was previously used here).
max_length = 64

In [11]:
# Tokenize the dataset and create labels
def tokenize_function(examples):
    """Tokenize a batch of examples and attach causal-LM labels.

    Intended for ``Dataset.map(..., batched=True)``: ``examples['text']`` is a list of
    strings. Each sequence is padded/truncated to the global ``max_length`` using the
    global ``tokenizer``.

    Returns:
        The tokenizer output plus a ``labels`` entry. Labels mirror ``input_ids`` (they are
        NOT pre-shifted; the causal-LM model shifts them internally), with padding positions
        set to -100 so the loss ignores them. Padding is located through
        ``attention_mask`` rather than the pad-token id, because the pad token is the EOS
        token here.

    Note:
        ``DataCollatorForLanguageModeling(mlm=False)`` rebuilds ``labels`` from
        ``input_ids`` per batch, so these labels only take effect without that collator.
    """
    tokenized = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)
    # Build fresh lists (a shallow .copy() would alias input_ids' inner lists).
    tokenized['labels'] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(tokenized['input_ids'], tokenized['attention_mask'])
    ]
    return tokenized

In [None]:
# Apply tokenize_function to every split; batched=True hands it lists of examples.
tokenized_datasets = shuffled_dataset.map(
    function=tokenize_function,
    batched=True,
)

In [13]:
# Return only the model inputs, as PyTorch tensors, when rows are read from the dataset.
torch_columns = ['input_ids', 'attention_mask', 'labels']
tokenized_datasets.set_format(type='torch', columns=torch_columns)

In [14]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Args:
        model: Any object whose ``named_parameters()`` yields ``(name, param)`` pairs
            where ``param`` has ``numel()`` and ``requires_grad`` (e.g. a
            ``torch.nn.Module`` or a PEFT-wrapped model).

    Prints one line of the form
    ``trainable params: T || all params: A || trainable%: P``.
    P is reported as 0.0 when the model has no parameters at all, instead of
    raising ZeroDivisionError.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

In [None]:
# Sanity check after get_peft_model: report the trainable vs. total parameter counts.
print_trainable_parameters(model)

In [16]:
# Per-device micro-batch size. Gradients are also accumulated over `batch` micro-batches,
# so each optimizer step sees batch * batch (= 64) examples per device.
batch = 8
training_args = TrainingArguments(
    output_dir='outputs',
    # optimisation schedule (20 epochs and max_steps=500 were also tried)
    num_train_epochs=4,
    learning_rate=2e-4,
    warmup_steps=10,
    # batching
    per_device_train_batch_size=batch,
    gradient_accumulation_steps=batch,
    dataloader_drop_last=False,  # keep the final, smaller batch
    # logging
    logging_steps=30,
)

# Earlier configuration that was tried: batch=2, num_train_epochs=12, learning_rate=1e-4,
# logging_steps=batch*2, auto_find_batch_size=True (other settings as above).

In [17]:
# Causal-LM collator (mlm=False): builds each batch's `labels` from its `input_ids`.
lm_collator = transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Trainer over the tokenized splits; the validation split is used by evaluate().
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=lm_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)

In [18]:
# Disable the key/value cache: it only helps incremental generation, not training.
model.config.use_cache = False

In [19]:
import matplotlib.pyplot as plt

# Per-epoch loss history, appended to by CustomCallback during trainer.train().
train_losses, eval_losses = [], []

class CustomCallback(transformers.TrainerCallback):
    """Record one training loss and one validation loss at the end of every epoch.

    Appends to the notebook-level ``train_losses`` / ``eval_losses`` lists. Evaluation
    goes through the global ``trainer``, because the callback API does not pass the
    Trainer itself.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        # Most recent *training* log record. The last record is not necessarily one:
        # - the trainer.evaluate() call below appends an 'eval_loss' record;
        # - with logging_steps larger than an epoch, no training record may exist yet.
        # The logged 'loss' is averaged over the latest logging window, not the whole epoch.
        # NaN (instead of KeyError/IndexError) leaves a gap in the plot.
        train_loss = next(
            (entry["loss"] for entry in reversed(state.log_history) if "loss" in entry),
            float("nan"),
        )
        train_losses.append(train_loss)

        # Full pass over the validation split.
        eval_loss = trainer.evaluate()["eval_loss"]
        eval_losses.append(eval_loss)

# Register the epoch-end loss recorder; add_callback instantiates the class itself.
trainer.add_callback(CustomCallback)

In [None]:
# Run fine-tuning; the registered CustomCallback records train/validation loss each epoch.
trainer.train()

In [None]:
# Plot the per-epoch training and validation losses recorded by CustomCallback.
plt.figure(figsize=(10, 5))
for losses, label in ((train_losses, 'Training Loss'), (eval_losses, 'Validation Loss')):
    epochs = range(1, len(losses) + 1)
    plt.plot(epochs, losses, label=label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
plt.grid()
plt.show()

In [23]:
# Save the LoRA adapter in PEFT's standard format (adapter weights + adapter_config.json);
# reload later with `PeftModel.from_pretrained(base_model, "lora_adapter")`.
model.save_pretrained("lora_adapter")

# NOTE: on a PEFT-wrapped model, state_dict() also contains every frozen GPT-2 base
# weight, so despite its name 'lora.pt' is a full-model checkpoint. Kept for existing
# code that loads this file.
torch.save(model.state_dict(), 'lora.pt')