In [8]:
import os
import sys
import logging
from typing import List, Union
import torch
from dataclasses import dataclass, field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    HfArgumentParser,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_from_disk, disable_caching
# Turn off the `datasets` cache globally for this session (call imported from `datasets` above).
disable_caching()

#! Wandb Project Name
# Exported through the environment before any trainer is built further down.
# NOTE(review): presumably read by the Weights & Biases reporting integration — confirm.
os.environ["WANDB_PROJECT"] = "Text2SQL"

# Configure root logging at INFO level and create a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
def prepare_dataset_for_training(dataset, tokenizer, prompt_file):
    """Render every example of ``dataset`` into a single ``"text"`` training string.

    The template stored in ``prompt_file`` is filled with ``str.format`` using
    the placeholders ``user_question``, ``table_metadata_string``, ``sql`` and
    ``eos_token``. The SQL answer is normalised so that it ends with exactly
    one trailing ``;`` (trailing whitespace is dropped first, so an answer such
    as ``"SELECT 1; "`` does not become ``"SELECT 1; ;"``).

    Args:
        dataset: Split mapping with a ``"train"`` entry exposing ``features``,
            and a ``map(function, remove_columns=...)`` method. Each example
            must provide the keys ``"question"``, ``"context"`` and ``"answer"``.
        tokenizer: Object exposing ``eos_token``; the token is substituted
            into the template.
        prompt_file: Path of the prompt template file (read as UTF-8).

    Returns:
        The result of ``dataset.map``: the same splits with the columns of the
        ``"train"`` split removed and a ``"text"`` column added.
    """
    # Explicit encoding: the default would otherwise depend on the platform locale.
    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt = f.read()

    # Snapshot the original column names as a plain list so they can be dropped
    # after rendering.
    columns = list(dataset["train"].features.keys())

    def preprocess_function(sample):
        # Guarantee a single terminating semicolon, ignoring trailing whitespace.
        sql = sample["answer"].rstrip()
        if not sql.endswith(";"):
            sql += ";"

        sample["text"] = prompt.format(
            user_question=sample["question"],
            table_metadata_string=sample["context"],
            sql=sql,
            eos_token=tokenizer.eos_token,
        ).strip()

        return sample

    train_dataset = dataset.map(
        preprocess_function,
        remove_columns=columns,
    )
    return train_dataset

In [5]:
# Load the tokenizer of the base checkpoint.
# NOTE(review): the checkpoint name is hard-coded here and is also the default of
# ModelConfig.model; the two must be kept in sync by hand.
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
# Register a dedicated "<|pad|>" token when the tokenizer has no pad token.
# The model's embeddings are resized to len(tokenizer) after loading to make room for it.
if not tokenizer.pad_token:
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
print(tokenizer.special_tokens_map)

{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<|pad|>', 'additional_special_tokens': ['▁<PRE>', '▁<MID>', '▁<SUF>', '▁<EOT>']}


In [7]:
# Read the pre-split dataset from disk, then render each example with the training prompt.
dataset = load_from_disk("./datasets/sql-create-context-split")
train_dataset = prepare_dataset_for_training(
    dataset,
    tokenizer,
    prompt_file="./prompts/prompt_v2_train.md",
)
# Display the first rendered training example as a visual sanity check.
train_dataset["train"][0]["text"]

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

'### Task\nGenerate a SQL query to answer [QUESTION]Which type of policy is most frequently used? Give me the policy type code.[/QUESTION]\n\n### Instructions\n- End the SQL query with ";"\n- Do not explain the Answer SQL\n\n### Database Schema\nThe query will run on a database with the following schema:\nCREATE TABLE policies (policy_type_code VARCHAR)\n\n### Answer\nGiven the database schema, here is the SQL query that answers [QUESTION]Which type of policy is most frequently used? Give me the policy type code.[/QUESTION]\n[SQL]SELECT policy_type_code FROM policies GROUP BY policy_type_code ORDER BY COUNT(*) DESC LIMIT 1;[/SQL]\n</s>'

In [23]:
@dataclass
class ModelConfig:
    """Model, quantisation and LoRA settings parsed alongside ``TrainingArguments``.

    Instances are created by ``HfArgumentParser`` from a JSON config file; the
    defaults below apply to any key the file omits.
    """

    # Checkpoint name/path handed to AutoModelForCausalLM.from_pretrained.
    model: str = field(default="codellama/CodeLlama-7b-Instruct-hf")
    # Dataset directory and prompt template path.
    # NOTE(review): the visible cells hard-code these same paths instead of reading these fields.
    dataset: str = field(default="./datasets/sql-create-context-split")
    prompt: str = field(default="./prompts/prompt_v2_train.md")
    # Forwarded to SFTTrainer as max_seq_length.
    max_seq_length: int = field(default=1024)
    # 8 -> load_in_8bit, 4 -> load_in_4bit; any other value leaves both flags False.
    bits: int = field(default=4)
    # Forwarded to BitsAndBytesConfig.bnb_4bit_quant_type.
    bnb_4bit_quant_type: str = field(default="nf4")
    # The remaining fields are forwarded one-to-one to peft.LoraConfig.
    r: int = field(default=16)
    lora_alpha: int = field(default=32)
    lora_dropout: float = field(default=0.1)
    # Either an explicit list of module names or a single string: the config file
    # used by this notebook passes the string "all-linear".
    target_modules: Union[List[str], str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    bias: str = field(default="none")
    init_lora_weights: Union[bool, str] = field(default=True)
    task_type: str = field(default="CAUSAL_LM")

In [24]:
# Parse one JSON file into both config objects: the parser is built over
# (ModelConfig, TrainingArguments) and its result is unpacked in that order.
parser = HfArgumentParser((ModelConfig, TrainingArguments))
model_config, training_args = parser.parse_json_file(json_file="./configs/codellama-v1.json")

In [7]:
# Derive the tensor dtype from the precision flags of the training arguments:
# fp16 takes precedence over bf16, and float32 is the fallback.
if training_args.fp16:
    torch_dtype = torch.float16
elif training_args.bf16:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float32
torch_dtype

torch.bfloat16

In [8]:
# Quantisation settings: `bits` selects 4-bit or 8-bit loading (neither flag is set
# for any other value); the 4-bit compute dtype follows the training precision.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=(model_config.bits == 4),
    load_in_8bit=(model_config.bits == 8),
    bnb_4bit_quant_type=model_config.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=torch_dtype,
)

# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run locally — confirm it is actually required for this model.
model = AutoModelForCausalLM.from_pretrained(
    model_config.model,
    quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
)

# Grow the embedding matrix to the tokenizer's current vocabulary size
# (the tokenizer may have gained a pad token after loading).
model.resize_token_embeddings(len(tokenizer))

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Embedding(32017, 4096)

In [9]:
# Apply PEFT's k-bit training preparation to the loaded model; gradient
# checkpointing is enabled only when the training arguments ask for it.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)

In [10]:
# Every LoRA hyper-parameter is copied verbatim from the parsed ModelConfig;
# the tuple lists the field names shared by ModelConfig and LoraConfig.
lora_config = LoraConfig(
    **{
        option: getattr(model_config, option)
        for option in (
            "r",
            "lora_alpha",
            "target_modules",
            "lora_dropout",
            "bias",
            "init_lora_weights",
            "task_type",
        )
    }
)
print(lora_config)

# Wrap the base model with the LoRA adapters described above.
model = get_peft_model(model, lora_config)

LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type='CAUSAL_LM', inference_mode=False, r=128, target_modules='all-linear', lora_alpha=128, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False))


In [11]:
# Report how many parameters are trainable after wrapping the model with LoRA.
model.print_trainable_parameters()

trainable params: 319,815,680 || all params: 7,058,370,560 || trainable%: 4.5310


In [15]:
# Marker after which tokens count towards the loss (completion-only training).
response_template = "[SQL]"

# SentencePiece-based Llama tokenizers encode a string differently at the start
# of a text than in the middle of one. In the rendered prompts "[SQL]" always
# follows a newline, so the template is encoded together with that newline and
# the newline's own tokens are sliced off. Passing the resulting ids (instead of
# the raw string) lets the collator match the marker as it appears in the data;
# if it cannot find the marker, the example is excluded from the loss.
_template_context = "\n"
_context_ids = tokenizer.encode(_template_context, add_special_tokens=False)
response_template_ids = tokenizer.encode(
    _template_context + response_template, add_special_tokens=False
)[len(_context_ids):]

data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)

In [None]:
# Build the supervised fine-tuning trainer on the rendered "train"/"val" splits.
#
# The notebook's tokenizer is passed explicitly: without it the trainer loads a
# fresh tokenizer for the base checkpoint, which lacks the pad token added
# earlier and therefore disagrees with the resized embeddings and with the
# collator; that fresh tokenizer would also be the one saved with the model.
# NOTE(review): newer TRL releases name this argument `processing_class` and
# expect `max_seq_length` on SFTConfig — adjust when upgrading TRL.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset["train"],
    eval_dataset=train_dataset["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    max_seq_length=model_config.max_seq_length,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]



In [None]:
# Run the fine-tuning loop with the trainer configured above.
trainer.train()

In [None]:
# Write the trained model to the output directory named in the training arguments.
trainer.save_model(training_args.output_dir)