In [None]:
import peft
import os
import sys
from peft import get_peft_model, LoraConfig, PeftModel, PeftConfig
import torch
import pandas as pd
import matplotlib.pyplot as plt
import pickle
from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, GPT2Tokenizer, TrainerCallback
# Resolve the repository root as two directories above the notebook's working
# directory, then put it on sys.path so the `src.` package import below resolves.
current_dir = os.getcwd()
up_two_levels = (os.pardir, os.pardir)
project_root = os.path.abspath(os.path.join(current_dir, *up_two_levels))
sys.path.append(project_root)
from src.data_processing.Formality_Transfer_Dataset import FormalityTransferDataset


In [None]:
# Input artefacts (pickled dataset splits and tokenizer), all resolved against the project root.
def _project_path(relative):
    """Return `relative` joined onto `project_root`."""
    return os.path.join(project_root, relative)


test_path = _project_path('data/processed/test.pkl')
train_path = _project_path('data/processed/train.pkl')
tune_path = _project_path('data/processed/tune.pkl')
tokeniser_path = _project_path('src/models/tokenizer/tokenizer.pkl')

# Also make the data_processing directory importable by bare module name.
# NOTE(review): presumably needed so the pickled datasets can resolve their class on load — confirm.
sys.path.append(_project_path('src/data_processing'))
print(test_path)

In [None]:
# Load the pickled dataset splits and the tokenizer.
# NOTE(review): pickle.load can execute arbitrary code stored in the file — only load
# trusted, locally produced pickles here.
def _unpickle(path):
    """Open `path` in binary mode and return the object deserialized from it."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


test: FormalityTransferDataset = _unpickle(test_path)
train: FormalityTransferDataset = _unpickle(train_path)
tune: FormalityTransferDataset = _unpickle(tune_path)
tokenizer: GPT2Tokenizer = _unpickle(tokeniser_path)
print(len(tokenizer))

In [None]:
# Load pretrained GPT-2 medium (device_map="auto" lets transformers choose where the
# weights go), then resize its token embeddings to the loaded tokenizer's length.
model = AutoModelForCausalLM.from_pretrained(
    'gpt2-medium',
    device_map="auto",
)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Define the LoRA adapter configuration for GPT-2's attention and MLP projections.
lora_config = LoraConfig(
    r=8,              # rank of the low-rank update matrices
    lora_alpha=16,    # LoRA scaling numerator (PEFT scales updates by lora_alpha / r)
    target_modules=["attn.c_attn", "attn.c_proj", "mlp.c_fc", "mlp.c_proj"],
    # GPT-2 implements these projections with transformers' Conv1D, which stores its
    # weight as (fan_in, fan_out). PEFT requires fan_in_fan_out=True for that layout;
    # leaving it False only triggers a warning and PEFT flips it itself.
    fan_in_fan_out=True,
    lora_dropout=0.05,
    bias="lora_only",  # also train the biases of the adapted layers
    task_type="CAUSAL_LM",
    # NOTE(review): the token embeddings (resized to len(tokenizer) above) are neither
    # targeted nor listed in `modules_to_save`, so embeddings for any newly added tokens
    # remain at their initial values during training — confirm this is intended.
)

In [None]:
# Wrap the base model with the LoRA adapters and report how many parameters are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Device placement is left to the Trainer: when constructed, it moves a single-device
# model to `training_args.device`. An explicit `model.to('cpu')` here would be undone on a
# single device, and would strand a multi-device (`device_map="auto"`) model on CPU.
# To force CPU training, pass `use_cpu=True` to TrainingArguments instead.

# Data collator in causal-LM mode (mlm=False: no masked-token objective).
# NOTE(review): padding a batch requires `tokenizer.pad_token` to be set — presumably
# configured in the pickled tokenizer; confirm.
data_collator = DataCollatorForLanguageModeling(
    tokenizer,
    mlm=False,
)

In [None]:
# Module-level lists the callback appends to: one entry per logged training loss and
# one entry per logged evaluation loss.
training_loss = []
validation_loss = []


class LossLoggerCallback(TrainerCallback):
    """Trainer callback that records `loss` and `eval_loss` values from each log event."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append any training/evaluation loss present in `logs` to the module-level lists."""
        if logs is None:
            return
        for log_key, target in (('loss', training_loss), ('eval_loss', validation_loss)):
            if log_key in logs:
                target.append(logs[log_key])

In [None]:
# Training hyper-parameters.
# Effective batch size per device per optimizer step = 6 (batch) * 8 (accumulation) = 48.
training_args = TrainingArguments(
    output_dir="logs",
    per_device_train_batch_size=6,  # lowered for memory
    per_device_eval_batch_size=5,   # lowered for memory
    gradient_accumulation_steps=8,  # keeps a larger effective batch without extra memory
    learning_rate=1e-5,  # NOTE(review): low for LoRA adapters (often ~1e-4); confirm intended
    num_train_epochs=6,
    logging_dir="logs/training",
    logging_steps=500,
    eval_strategy="steps",
    eval_steps=500,  # same cadence as logging_steps so the train/eval loss lists line up
    # One checkpoint per epoch. `save_steps` is ignored under this strategy, so it is not set.
    save_strategy="epoch",
    report_to="tensorboard",
    fp16=False,  # keep as False on MPS
)

In [None]:
# Disable the model's cache for training.
model.config.use_cache = False

# Create the Trainer, wiring in the data, the collator and the loss-recording callback.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=tune,
    data_collator=data_collator,
    callbacks=[LossLoggerCallback()],
)

In [None]:
# Run fine-tuning; the object returned by trainer.train() is shown as the cell output.
train_result = trainer.train()
train_result

In [None]:
# Save the LoRA adapter, the tokenizer and the underlying (resized) base model.
# Output directories are anchored at `project_root`, matching where the tokenizer and
# datasets are loaded from; bare relative paths would resolve against the notebook's
# working directory (two levels below the root) instead.
lora_model_dir = os.path.join(project_root, "src/models/lora_trained")
base_model_dir = os.path.join(project_root, "src/models/base_model")
tokenizer_dir = os.path.join(project_root, "src/models/trained_tokenizer")


def _base_state_dict_without_lora(peft_model):
    """Return the wrapped transformer's weights with the LoRA parts stripped out.

    PEFT keeps each adapted layer's original weights under `<name>.base_layer.` and adds
    `lora_*` sub-modules. Saving the wrapped transformer as-is would write those renamed
    keys, which a plain `AutoModelForCausalLM.from_pretrained` cannot map back onto
    GPT-2. Here the `lora_*` tensors are dropped and `.base_layer.` is collapsed, without
    mutating the PEFT model (so re-running this cell is safe).
    NOTE: with bias="lora_only" the adapted layers' biases were trained, so the saved
    biases are the trained values.
    """
    base_state = peft_model.get_base_model().state_dict()
    return {
        key.replace(".base_layer.", "."): tensor
        for key, tensor in base_state.items()
        if "lora_" not in key
    }


trainer.model.save_pretrained(lora_model_dir)
tokenizer.save_pretrained(tokenizer_dir)
trainer.model.get_base_model().save_pretrained(
    base_model_dir,
    state_dict=_base_state_dict_without_lora(trainer.model),
)

print("Training complete. Model saved.")

In [None]:
# Save loss data.
# The two lists can differ in length. pd.Series pads the shorter column with NaN, so
# every recorded value is kept. The old slice raised a length-mismatch ValueError when
# fewer eval losses than train losses existed, and silently dropped data otherwise.
loss_data = pd.DataFrame({
    'training_loss': pd.Series(training_loss, dtype=float),
    'validation_loss': pd.Series(validation_loss, dtype=float),
})
loss_data.to_csv('loss_data_2.csv', index=False)

# Convert list positions into optimizer-step numbers: training loss is logged every
# `logging_steps` steps and validation loss every `eval_steps` steps (see training_args).
# NOTE(review): assumes no extra log events (e.g. logging_first_step) — confirm.
train_steps = [(i + 1) * training_args.logging_steps for i in range(len(training_loss))]
validation_steps = [(i + 1) * training_args.eval_steps for i in range(len(validation_loss))]

# Plot the training and validation loss against the step at which each was recorded.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_steps, training_loss, label='Training Loss')
ax.plot(validation_steps, validation_loss, label='Validation Loss')
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
fig.savefig('loss_plot_2.png')
plt.show()