# Instruction Finetuning using IA3

This notebook shows how to perform instruction fine-tuning with the IA3 parameter-efficient fine-tuning (PEFT) method. The task is supervised fine-tuning (SFT) of a Mistral-family model for natural-language-to-SQL query generation.

In [None]:
import os
os.environ["WANDB_PROJECT"]="mistral_instruct_finetuning"

from enum import Enum
from functools import partial
import pandas as pd
import torch
import json

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig, set_seed
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import get_peft_model, IA3Config, TaskType

# Seed every RNG that transformers' set_seed manages (random, numpy, torch)
# so data shuffling and adapter initialisation are reproducible.
seed = 42
set_seed(seed=seed)

### Data preprocessing

In [None]:
# Hub id of the base model (loaded later via from_pretrained).
model_name = "ministral/Ministral-3b-instruct"
# Hub id of the dataset passed to load_dataset.
dataset_name = "wikisql"
def preprocess(sample):
    """Format one WikiSQL row as a single prompt/completion training string.

    Args:
        sample: One dataset row. Reads ``sample["table"]["header"]``,
            ``sample["table"]["id"]``, ``sample["question"]`` and
            ``sample["sql"]["human_readable"]``.

    Returns:
        ``{"content": str}`` containing the table id, the column list, the
        natural-language question and the target SQL query, terminated by
        the literal ``</s>``.
    """
    column_names = sample["table"]["header"]
    table_id = sample["table"]["id"]
    natural_query = sample["question"]
    # The human-readable SQL names its table with the placeholder "table"
    # (presumably always as "... FROM table ...", per the WikiSQL format).
    # Replace only that FROM-clause placeholder: a blanket
    # str.replace("table", ...) would also rewrite column names and values
    # that merely contain the substring (e.g. "Timetable" -> "Time<id>").
    sql_query = sample["sql"]["human_readable"].replace(
        "FROM table", f"FROM {table_id}", 1
    )
    content = f"Table: {table_id}\n Columns: {column_names}\n Natural Query: {natural_query}\n SQL Query: {sql_query}</s>"
    return {"content": content}

# Download the dataset and collapse every row into a single "content"
# string; the train split's original columns are removed from each split.
dataset = load_dataset(dataset_name)
train_columns = dataset["train"].column_names
dataset = dataset.map(
    preprocess,
    remove_columns=train_columns,
    batched=False,
)
print(dataset)
first_train_row = dataset["train"][0]
print(first_train_row)

In [None]:
example_text = dataset["train"][6]["content"]  # one formatted training string
print(example_text)

In [None]:
num_train_rows = len(dataset["train"])  # size of the training split
print(num_train_rows)

In [None]:
from datasets import DatasetDict

# Rename the formatted "content" column to "text" on every split of the
# DatasetDict. NOTE(review): presumably the trainer configured later reads a
# "text" field — confirm against its dataset_text_field setting.
dataset = dataset.rename_column("content", "text")

# Show the renamed schema.
print(dataset)

### Create the PEFT model

### IA3 Config 

In [None]:
# (IA)^3 learns rescaling vectors for the listed modules. Modules that are
# also listed in feedforward_modules have their vector applied to the layer
# input instead of its output (per the PEFT IA3Config docs); that list must
# be a subset of target_modules.
ia3_targets = ["k_proj", "v_proj", "down_proj"]
peft_config = IA3Config(
    task_type=TaskType.CAUSAL_LM,
    target_modules=ia3_targets,
    feedforward_modules=["down_proj"],
)

In [None]:
# The completion-only collator masks label tokens up to and including this
# template, so the loss covers only the SQL that follows "SQL Query:".
# NOTE(review): the template appears mid-string ("\n SQL Query:") in the
# training text; confirm its standalone tokenization matches in context,
# or the collator will fail to find it.
response_template = "SQL Query:"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    # Presumably id 0 rather than EOS, so the "</s>" appended to every
    # example is not mistaken for padding and masked out of the labels.
    # NOTE(review): confirm id 0 is a harmless special token for this tokenizer.
    tokenizer.pad_token_id = 0
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Cast non-trainable params to float16 (matching fp16=True in the training
# arguments) to save memory.
# NOTE(review): no PEFT adapter has been applied at this point, so every
# parameter returned by from_pretrained still has requires_grad=True and
# this loop casts nothing. It only takes effect if the base weights are
# frozen first — confirm the intended ordering.
for p in model.parameters():
    if not p.requires_grad:
        p.data = p.to(torch.float16)

## Training 

In [None]:
# Hyperparameters for the SFT run.
output_dir = "Ministral_3b_sql_instruct"
per_device_train_batch_size = 8
per_device_eval_batch_size = 8
# Effective per-device batch = per_device_train_batch_size * this value.
gradient_accumulation_steps = 4
logging_steps = 5
learning_rate = 5e-4
max_grad_norm = 1.0
num_train_epochs = 1
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
# Not consumed by TrainingArguments below; presumably passed to the
# trainer in a later cell (TODO confirm).
max_seq_length = 256

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_strategy="no",
    # `evaluation_strategy` was deprecated in favour of `eval_strategy`
    # and has since been removed; the old keyword raises TypeError on
    # current transformers releases.
    eval_strategy="epoch",
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    weight_decay=0.1,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
    fp16=True,  # fp16 mixed precision
    report_to=["tensorboard", "wandb"],
    # Pushing requires Hugging Face Hub authentication; the repo is private.
    hub_private_repo=True,
    push_to_hub=True,
    num_train_epochs=num_train_epochs,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)