# Fine-Tuning Google's FLAN-T5 model for Topic Labeling

by Andreas Sünder

## Install required packages (only once)

Run this in a notebook cell (`%pip` is an IPython magic, not a shell command):

```python
%pip install -r requirements.txt
```

## Setup

Open up a terminal and run the following commands:

```bash
huggingface-cli login
wandb login
```

Set up the following variables:

In [None]:
# Dataset identifier passed to `load_dataset`.
DATASET_NAME = 'textminr/topic-labeling'
# Base checkpoint identifier used to load both the model and its tokenizer.
MODEL_NAME = 'google/flan-t5-xl'

# Reused for the W&B project, the local output directory and the run-name prefix.
PROJECT_NAME = 'tl_qlora_flan-t5-xl'
# IPython magic: sets the WANDB_PROJECT environment variable to PROJECT_NAME.
%env WANDB_PROJECT=$PROJECT_NAME

Load the dataset:

In [None]:
from datasets import load_dataset, concatenate_datasets

# Load the dataset and expose its `label` column under the name
# `topic_label`, which is what the preprocessing step reads.
dataset = load_dataset(DATASET_NAME).rename_column('label', 'topic_label')

# Report the size of the training split.
print(f"Train dataset size: {dataset['train'].num_rows}")
# print(f"Test dataset size: {dataset['validation'].num_rows}")

Define a prompt template:

In [None]:
# Instruction prefix for every model input; `{}` is filled via `str.format`
# with one entry of the `top_terms` column inside `preprocess_data`.
prompt_template = 'Provide a topic label: {}'

## Load the model and tokenizer

In [None]:
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer

# Quantization settings: 4-bit loading with the NF4 quantization type and
# double quantization enabled; the compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_quant_type='nf4',
  bnb_4bit_use_double_quant=True,
  bnb_4bit_compute_dtype=torch.bfloat16,
)

# The tokenizer is independent of the quantized weights, so load it first.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the base seq2seq model with the quantization settings above.
model = AutoModelForSeq2SeqLM.from_pretrained(
  MODEL_NAME,
  quantization_config=bnb_config,
  torch_dtype=torch.bfloat16,
  device_map='auto',
)

## Prepare data

In [None]:
# Maximum token lengths used by `preprocess_data`: prompts are capped at
# `max_source_length` tokens, topic labels at `max_target_length` tokens.
max_source_length = 130
max_target_length = 30

def preprocess_data(sample, padding: str = 'max_length'):
  """Tokenize a batch of prompts and topic labels for seq2seq training.

  Args:
    sample: Batch mapping with a `top_terms` and a `topic_label` entry, each
      holding one item per example (presumably strings, as delivered by
      `dataset.map(..., batched=True)`).
    padding: Padding strategy forwarded to both tokenizer calls. With the
      default 'max_length', pad tokens in the labels are replaced by -100.

  Returns:
    The tokenizer output for the prompts, extended with a `labels` entry
    that holds the token ids of the topic labels.
  """
  prompts = [prompt_template.format(top_terms) for top_terms in sample['top_terms']]
  model_inputs = tokenizer(
    prompts,
    truncation=True,
    # Honor the caller's choice instead of always padding to max length.
    padding=padding,
    max_length=max_source_length,
  )

  labels = tokenizer(
    text_target=list(sample['topic_label']),
    truncation=True,
    padding=padding,
    max_length=max_target_length,
  )

  label_ids = labels['input_ids']
  if padding == 'max_length':
    # Mask padded label positions with -100, the same value the data collator
    # uses as `label_pad_token_id` (conventionally ignored by the loss).
    pad_id = tokenizer.pad_token_id
    label_ids = [
      [(token_id if token_id != pad_id else -100) for token_id in sequence]
      for sequence in label_ids
    ]

  model_inputs['labels'] = label_ids
  return model_inputs

# `batched=True`: `preprocess_data` is called with batches of rows, which is
# why it iterates over the `top_terms` and `topic_label` columns.
tokenized_dataset = dataset.map(preprocess_data, batched=True)

## Set up LoRA

In [None]:
from peft import prepare_model_for_kbit_training
# Prepare the quantized base model for k-bit training before the LoRA
# adapters are attached in the next cell.
model = prepare_model_for_kbit_training(model)

In [None]:
from peft import LoraConfig, TaskType, get_peft_model

# LoRA adapter settings for a seq2seq model: rank 4, alpha 16 and 5% dropout
# on the modules named `q`, `k`, `v` and `o`.
config = LoraConfig(
  task_type=TaskType.SEQ_2_SEQ_LM,
  target_modules=['q', 'k', 'v', 'o'],
  r=4,
  lora_alpha=16,
  lora_dropout=0.05,
  bias='none',
)

# Wrap the base model with the adapters and print the trainable-parameter
# summary.
model = get_peft_model(model, config)
model.print_trainable_parameters()

## DataCollator

In [None]:
from transformers import DataCollatorForSeq2Seq

# Same value that `preprocess_data` writes into padded label positions.
label_pad_token_id = -100

# Collator that assembles the training batches handed to the trainer.
data_collator = DataCollatorForSeq2Seq(
  tokenizer,
  model=model,
  pad_to_multiple_of=8,
  label_pad_token_id=label_pad_token_id,
)

## Run training

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datetime import datetime

training_args = Seq2SeqTrainingArguments(
  # Output, logging and reporting. The run name carries a minute-resolution
  # timestamp so repeated runs can be told apart.
  output_dir=f'models/{PROJECT_NAME}',
  run_name=f'{PROJECT_NAME}-{datetime.now().strftime("%Y-%m-%d-%H-%M")}',
  report_to='wandb',
  logging_dir='./logs',
  logging_steps=10,
  save_strategy='no',
  # Optimization.
  num_train_epochs=1,
  learning_rate=1e-3,
  optim='paged_adamw_8bit',
  bf16=True,
  per_device_train_batch_size=8,
  # Evaluation settings; the step-based evaluation below is switched off.
  per_device_eval_batch_size=8,
  predict_with_generate=True,
  # do_eval=True,
  # evaluation_strategy='steps',
  # eval_steps=200,
)

# Build the trainer from the adapter-wrapped model, the tokenized training
# split and the collator; no evaluation dataset is passed.
trainer = Seq2SeqTrainer(
  model=model,
  tokenizer=tokenizer,
  data_collator=data_collator,
  args=training_args,
  train_dataset=tokenized_dataset['train'],
  # eval_dataset=tokenized_dataset['validation'],
)

# Run the fine-tuning loop.
trainer.train()

In [None]:
# Upload the trained model to the Hub repository `textminr/tl-flan-t5-xl`.
# NOTE(review): `model` is the PEFT-wrapped model at this point, so presumably
# only the adapter weights are pushed — confirm.
model.push_to_hub('textminr/tl-flan-t5-xl')