# Phase 2: Model Training

In this phase, we fine-tune the Flan-T5 model with LoRA (Low-Rank Adaptation) to improve its text summarization performance. The training pipeline is built with the Hugging Face Transformers and PEFT libraries.

## Steps:

1. **Load the Base Model**

2. **Integrate LoRA for Fine-Tuning**

3. **Configure the Training Pipeline**

4. **Fine-Tune the Model**

LoRA is a parameter-efficient fine-tuning method: the pretrained Flan-T5 weights stay frozen and only small low-rank adapter matrices are trained, which keeps the number of trainable parameters to a small fraction of the full model.

---

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model

import torch

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

### Load the Flan-T5-base model

In [None]:
model_name = "google/flan-t5-base"

base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
base_model.to(device)

### Load the tokenizer, used to encode inputs and decode model outputs

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

### Load the dataset

In [None]:
import pickle
with open('data/dataset_t5_base.pkl', 'rb') as file:
    dataset = pickle.load(file)

### Setup the PEFT configuration

In [None]:
# Set up LoRA configuration
lora_config = LoraConfig(
    r=24,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)

# Wrap model with LoRA
peft_model = get_peft_model(base_model, lora_config)
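
To make the configuration above concrete, here is a minimal, dependency-free sketch of the LoRA update for a single linear layer. It is an illustration of the idea, not the PEFT implementation. LoRA keeps the pretrained weight `W0` frozen and learns a low-rank correction, so the layer computes `W0 @ x + (lora_alpha / r) * B @ (A @ x)`. With `r=24` and `lora_alpha=32`, the scaling factor is 32/24, roughly 1.33. The helper names (`matvec`, `lora_forward`) and the toy dimensions are assumptions made for this example.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list of floats)."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(w0, lora_a, lora_b, x, r, lora_alpha):
    """Compute W0 @ x + (lora_alpha / r) * B @ (A @ x) for one linear layer."""
    scaling = lora_alpha / r
    base_out = matvec(w0, x)                      # frozen pretrained path
    low_rank = matvec(lora_b, matvec(lora_a, x))  # trainable low-rank path
    return [b + scaling * l for b, l in zip(base_out, low_rank)]


# Toy sizes for illustration only: a 3x3 layer with rank-2 adapters.
r, lora_alpha = 2, 4
w0 = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
lora_a = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]    # shape (r, in_features)
lora_b_init = [[0.0, 0.0] for _ in range(3)]   # shape (out_features, r), zero at start
x = [1.0, 1.0, 1.0]

# B starts at zero, so the adapted layer initially reproduces the base layer.
out_init = lora_forward(w0, lora_a, lora_b_init, x, r, lora_alpha)

# Once B is non-zero (as it would be after training), the output shifts.
lora_b_trained = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
out_trained = lora_forward(w0, lora_a, lora_b_trained, x, r, lora_alpha)

print(out_init)
print(out_trained)
```

Only `A` and `B` receive gradients during fine-tuning, which is why the wrapped model reports far fewer trainable parameters than the base model.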

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Return a summary of trainable vs. total parameter counts (the caller prints it)."""
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    percentage = 100 * trainable_model_params / all_model_params
    return (
        f"trainable model parameters: {trainable_model_params}\n"
        f"all model parameters: {all_model_params}\n"
        f"percentage of trainable model parameters: {percentage:.2f}%"
    )


In [None]:
print(print_number_of_trainable_model_parameters(peft_model))
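
The trainable-parameter count can also be estimated by hand as a sanity check. A rank-`r` adapter on a linear layer with `in_features` inputs and `out_features` outputs adds `r * (in_features + out_features)` parameters. The sketch below assumes the published Flan-T5-base configuration (hidden size 768, 12 encoder and 12 decoder blocks, and 768 x 768 `q` and `v` projections in every self-attention and cross-attention layer). Those figures are not stated in this notebook, so treat the result as an estimate to compare against the printed output.

```python
def lora_params_per_layer(in_features, out_features, r):
    """Parameters added by one LoRA adapter: A is (r, in), B is (out, r)."""
    return r * in_features + out_features * r


# Assumed Flan-T5-base dimensions (not taken from this notebook).
d_model = 768          # input size of the q/v projections
inner_dim = 768        # output size of the q/v projections
encoder_blocks = 12    # one self-attention layer each
decoder_blocks = 12    # one self-attention and one cross-attention layer each

r = 24                 # same rank as in lora_config
targets_per_attention = 2  # target_modules=["q", "v"]

attention_layers = encoder_blocks + 2 * decoder_blocks
adapted_layers = attention_layers * targets_per_attention
estimated_trainable = adapted_layers * lora_params_per_layer(d_model, inner_dim, r)

print(f"adapted linear layers: {adapted_layers}")
print(f"estimated trainable parameters: {estimated_trainable}")
```

If the printed count from `peft_model` differs from this estimate, the most likely cause is that one of the assumed dimensions does not match the checkpoint actually loaded.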

### Create the data collator, set the training parameters and create the trainer

In [None]:
# Define the data collator to pad each batch dynamically, since the tokenized examples are lists of variable length.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=peft_model)
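
As a rough picture of what dynamic padding means, the sketch below pads a small batch by hand. It is a simplified stand-in, not the library code: the real `DataCollatorForSeq2Seq` also returns tensors and, when given the model, can prepare decoder inputs. The function name `pad_batch` and the token ids are made up for illustration, and the example assumes each dataset item has `input_ids` and `labels` fields. The pad id 0 is the T5 pad token, and -100 is the collator's default label padding value, which the loss function ignores.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad input_ids and labels to the longest sequence in the batch."""
    max_input_len = max(len(f["input_ids"]) for f in features)
    max_label_len = max(len(f["labels"]) for f in features)

    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_in = len(f["input_ids"])
        n_lab = len(f["labels"])
        # Inputs are padded with the tokenizer's pad token.
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * (max_input_len - n_in))
        # The attention mask marks real tokens with 1 and padding with 0.
        batch["attention_mask"].append([1] * n_in + [0] * (max_input_len - n_in))
        # Labels are padded with -100 so padded positions do not contribute to the loss.
        batch["labels"].append(f["labels"] + [label_pad_token_id] * (max_label_len - n_lab))
    return batch


# Two hypothetical tokenized examples of different lengths.
features = [
    {"input_ids": [21, 57, 9, 1], "labels": [88, 1]},
    {"input_ids": [34, 1], "labels": [45, 12, 7, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"])
print(batch["attention_mask"])
print(batch["labels"])
```

Padding per batch rather than to a global maximum length keeps short batches small, which saves memory and compute during training.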

In [None]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results_training",
    report_to="none",  # Disable all external logging integrations (e.g. W&B)
    evaluation_strategy="steps",  # Newer Transformers releases name this argument eval_strategy
    eval_steps=100,
    learning_rate=1e-3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    logging_steps=100,
)
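
To see what these settings imply for the length of a run, the following sketch estimates the number of optimizer steps and evaluation passes. It assumes a single device and no gradient accumulation (neither is configured above), and the dataset size of 10,000 examples is purely hypothetical, since the actual size is not shown in this notebook. The function name `training_schedule` is also an illustration.

```python
import math


def training_schedule(num_train_examples, per_device_batch_size, num_epochs, eval_steps):
    """Estimate total optimizer steps and number of step-based evaluations.

    Assumes one device and no gradient accumulation.
    """
    steps_per_epoch = math.ceil(num_train_examples / per_device_batch_size)
    total_steps = steps_per_epoch * num_epochs
    # With a step-based strategy, evaluation runs every eval_steps optimizer steps.
    num_evaluations = total_steps // eval_steps
    return total_steps, num_evaluations


# Hypothetical dataset size; replace with len(dataset["train"]).
total_steps, num_evaluations = training_schedule(
    num_train_examples=10_000,
    per_device_batch_size=4,
    num_epochs=1,
    eval_steps=100,
)
print(f"total steps: {total_steps}, evaluations: {num_evaluations}")
```

Since each evaluation pass runs over the whole validation split, raising `eval_steps` is a simple way to shorten a run when evaluation dominates the wall-clock time.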

In [None]:
# Get LoRA trainable parameters
lora_parameters = [p for p in peft_model.parameters() if p.requires_grad]

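# Define the optimizer over the LoRA parameters only. Because it is passed to the Trainer explicitly,
# its learning rate is the one actually used, so reuse the value from training_args to keep them in sync.
optimizer = torch.optim.AdamW(lora_parameters, lr=training_args.learning_rate)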

In [None]:
# Initialize the Trainer
trainer = Trainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    optimizers=(optimizer, None)
)

In [None]:
# Start fine-tuning
trainer.train()

In [None]:
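model_path = "./peft_model_trained_google_flan_t5_base_dialogue_summarization"

# For a PEFT model, save_pretrained writes only the LoRA adapter weights and config, not the full base model
trainer.model.save_pretrained(model_path)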
tokenizer.save_pretrained(model_path)
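
Because only the adapter is saved, using the fine-tuned model later means loading the base checkpoint again and attaching the adapter to it. The sketch below shows one way to do that with `PeftModel.from_pretrained`. The helper names `load_for_inference` and `summarize`, and the `max_new_tokens` default, are assumptions made for this example rather than part of the notebook.

```python
def load_for_inference(adapter_path, base_model_name="google/flan-t5-base"):
    """Load the base model, then attach the saved LoRA adapter to it."""
    # Imports are local so that defining this helper does not require the libraries.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
    model = PeftModel.from_pretrained(base, adapter_path)
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    model.eval()  # disable dropout, including lora_dropout
    return model, tokenizer


def summarize(model, tokenizer, text, max_new_tokens=64):
    """Generate a summary for one input string on the CPU."""
    import torch

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # PEFT models expect keyword arguments in generate().
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

A typical call would be `model, tokenizer = load_for_inference(model_path)` followed by `summarize(model, tokenizer, text)`, where `text` should use the same prompt format as the training examples.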
