## Fine tuning LLM
This notebook shows how to fine tune LLM model. Fine-tuning is advanced technique which can modify model parameters to perform better in specific task. 

In [None]:
%pip install \
datasets \
transformers \
evaluate \
torch \
torchdata \
rouge_score \
loralib \
peft

In [1]:
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer

2024-09-15 17:40:22.676461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-15 17:40:22.720989: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-15 17:40:22.734283: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-15 17:40:22.808849: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
model_name='google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset_train = load_dataset("gretelai/synthetic_text_to_sql", split='train[0%:80%]')
dataset_valid = load_dataset("gretelai/synthetic_text_to_sql", split='train[80%:100%]')
dataset_test = load_dataset("gretelai/synthetic_text_to_sql", split='test')

dataset_train



Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 80000
})

### Check results with zero-shot

In [3]:
example_indexes = [37, 67]
for i, index in enumerate(example_indexes):
    schema = dataset_test[index]['sql_context']
    instruction = dataset_test[index]['sql_prompt']
    sql = dataset_test[index]['sql']

    prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.

Schema: {schema}

Instruction: {instruction}

SQL query:
    """

    inputs = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(
        model.generate(
            inputs["input_ids"],
            max_new_tokens=50,
        )[0],
        skip_special_tokens=True
    )

    print('Example ', i + 1)
    print(f'Task:\n{instruction}')
    print(f'Schema:\n{schema}')
    print(f'CORRECT SQL:\n{sql}')
    print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')

Example  1
Task:
What is the number of medical supplies distributed by each organization, in East Africa, for the last 3 years, and the total cost of the supplies?
Schema:
CREATE TABLE medical_supplies (supply_id INT, organization_id INT, location VARCHAR(255), supply_type VARCHAR(255), supply_cost DECIMAL(10,2), distribution_date DATE); INSERT INTO medical_supplies VALUES (1, 1, 'Country A', 'Medicine', 5000, '2020-01-01'); INSERT INTO medical_supplies VALUES (2, 1, 'Country A', 'Medical Equipment', 7000, '2021-01-01'); INSERT INTO medical_supplies VALUES (3, 2, 'Country B', 'Vaccines', 10000, '2021-01-01'); INSERT INTO medical_supplies VALUES (4, 2, 'Country B', 'First Aid Kits', 8000, '2020-01-01');
CORRECT SQL:
SELECT organization_id, location as region, COUNT(*) as number_of_supplies, SUM(supply_cost) as total_supply_cost FROM medical_supplies WHERE location = 'East Africa' AND distribution_date >= DATE_SUB(CURRENT_DATE, INTERVAL 3 YEAR) GROUP BY organization_id, location;
MODEL G

### Conduct model fine tuning process
Fine tune model using HuggingFace Trainer class

In [4]:
# Preprocess dataset
def tokenize(row):
    prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.

Schema: {row["sql_context"]}

Instruction: {row["sql_prompt"]}

SQL query:
    """
    row['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids.squeeze(0)
    row['labels'] = tokenizer(row["sql"], padding="max_length", truncation=True,return_tensors="pt").input_ids.squeeze(0)

    return row
columns = dataset_train.column_names
tokenized_datasets_train = dataset_train.map(tokenize)
tokenized_datasets_train = tokenized_datasets_train.remove_columns(columns)
print("Train:")
print(tokenized_datasets_train.shape)
print(tokenized_datasets_train)

tokenized_datasets_valid = dataset_valid.map(tokenize)
tokenized_datasets_valid = tokenized_datasets_valid.remove_columns(columns)
print("Valid:")
print(tokenized_datasets_valid.shape)
print(tokenized_datasets_valid)

tokenized_datasets_test = dataset_test.map(tokenize)
tokenized_datasets_test = tokenized_datasets_test.remove_columns(columns)
print("Test:")
print(tokenized_datasets_test.shape)
print(tokenized_datasets_test)

Map:   0%|          | 0/80000 [00:00<?, ? examples/s]

Train:
(80000, 2)
Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 80000
})
Valid:
(20000, 2)
Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 20000
})
Test:
(5851, 2)
Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 5851
})


In [5]:
print(tokenized_datasets_train[0])

{'input_ids': [30355, 15, 26622, 11, 4210, 139, 12558, 11417, 5, 23483, 163, 28, 12558, 11417, 406, 136, 1151, 2850, 5, 10248, 51, 9, 10, 205, 4386, 6048, 332, 17098, 1085, 6075, 41, 7, 4529, 6075, 834, 23, 26, 3, 13777, 6, 564, 3, 3463, 4, 382, 6, 1719, 3, 3463, 4, 382, 3670, 3, 14750, 24203, 3388, 5647, 1085, 6075, 41, 7, 4529, 6075, 834, 23, 26, 6, 564, 6, 1719, 61, 3, 21712, 5078, 134, 4077, 6, 3, 31, 18300, 531, 15, 31, 6, 3, 31, 22969, 31, 201, 4743, 6, 3, 31, 683, 152, 15, 3931, 31, 6, 3, 31, 22081, 31, 3670, 205, 4386, 6048, 332, 17098, 14592, 834, 7, 4529, 41, 7, 4529, 834, 23, 26, 3, 13777, 6, 1085, 6075, 834, 23, 26, 3, 13777, 6, 2908, 17833, 6, 1048, 834, 5522, 309, 6048, 3670, 3, 14750, 24203, 3388, 5647, 14592, 834, 7, 4529, 41, 7, 4529, 834, 23, 26, 6, 1085, 6075, 834, 23, 26, 6, 2908, 6, 1048, 834, 5522, 61, 3, 21712, 5078, 134, 4077, 6, 1914, 5864, 6, 3, 31, 1755, 2658, 14772, 14772, 31, 201, 4743, 6, 1914, 4261, 6, 3, 31, 19818, 18930, 357, 14772, 31, 201, 6918, 6, 35

In [6]:
# Perform fine tuning

output_dir = f'./text-to-sql-translation'

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=1e-5,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_steps=1
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets_train,
    eval_dataset=tokenized_datasets_valid
)

In [None]:
trainer.train()

### Test fine-tuned model 

In [None]:
tuned_model = AutoModelForSeq2SeqLM.from_pretrained("./text-to-sql-translation", torch_dtype=torch.bfloat16)

In [None]:
index = 78
schema = dataset_test[index]['sql_context']
instruction = dataset_test[index]['sql_prompt']
sql = dataset_test[index]['sql']

prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.

Schema: {schema}

Instruction: {instruction}

SQL query:
    """

inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

tuned_output = tokenizer.decode(
    tuned_model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

print(f'Task:\n{instruction}')
print(f'Schema:\n{schema}')
print(f'CORRECT SQL:\n{sql}')
print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')
print(f'MODEL GENERATION SQL - FINE TUNED:\n{tuned_output}\n')

### PEFT / LoRA
Fine tune model using PEFT/LoRA techniques

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

peft_model = get_peft_model(model, lora_config)

In [None]:
output_dir = f'./peft-text-to-sql-translation'

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning.
    num_train_epochs=1,
    logging_steps=1,
    max_steps=1
)

peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets_train,
)

In [None]:
peft_trainer.train()

peft_model_path="./peft-text-to-sql-translation"

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

In [None]:
from peft import PeftModel, PeftConfig

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

peft_model = PeftModel.from_pretrained(peft_model_base,
                                       './peft-text-to-sql-translation',
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False)

### Evaluate PEFT fine tuned model

In [None]:
index = 78
schema = dataset_test[index]['sql_context']
instruction = dataset_test[index]['sql_prompt']
sql = dataset_test[index]['sql']

prompt = f"""
Translate schema and description into SQL query. Respond only with SQL query without any additional characters.

Schema: {schema}

Instruction: {instruction}

SQL query:
    """

inputs = tokenizer(prompt, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

tuned_output = tokenizer.decode(
    tuned_model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

peft_output = tokenizer.decode(
    tuned_model.generate(
        inputs["input_ids"],
        max_new_tokens=50,
    )[0],
    skip_special_tokens=True
)

print(f'Task:\n{instruction}')
print(f'Schema:\n{schema}')
print(f'CORRECT SQL:\n{sql}')
print(f'MODEL GENERATION SQL - ZERO SHOT:\n{output}\n')
print(f'MODEL GENERATION SQL - FINE TUNED:\n{tuned_output}\n')
print(f'MODEL GENERATION SQL - PEFT/LoRA TUNED:\n{peft_output}\n')