In [1]:
from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from accelerate import Accelerator

In [2]:
@dataclass
class ScriptArguments:
    """
    Arguments for the fine_tuning
    """
    per_device_train_batch_size: Optional[int] = field(default=2)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[int] = field(default=0.001)
    lora_alpha= 64,
    lora_dropout =  0.5,
    lora_r = 32
    max_seq_length: Optional[int] = field(default=8192)
    model_name = "google/gemma-2b"
    fp16 = False
    bf16 = False
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"},
    )
    max_steps: int = field(default=1000, metadata={"help": "How many optimizer update steps to take"}),
    epochs : int = field(default=5, metadata={"help": "How many epochs to train for"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})
    save_steps: int = field(default=100, metadata={"help": "Save checkpoint every X updates steps."})
    logging_steps: int = field(default=200, metadata={"help": "Log every X updates steps."})
    output_dir: str = field(
        default="./gemma/results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    logging_dir: str = field(
        default="./gemma-2b/logs",
        metadata={"help": "The output directory where the logs will be written."},
    )
    eval_steps: int = field(default=200, metadata={"help": "How often to evaluate the model"})

parser = HfArgumentParser(ScriptArguments)
# Parse the arguments, ignoring unrecognized ones
script_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)


In [3]:
# Load the GG model - this is the local one, update it to the one on the Hub
model_id = "google/gemma-2b"
access_token = "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"

quantization_config = BitsAndBytesConfig(
     load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

In [4]:
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map ="auto",
    attn_implementation="eager"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [6]:
#Lora config
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

In [7]:
# Load dataset
train_dataset_url = "./small_dataset/train.jsonl"
test_dataset_url ="./small_dataset/test.jsonl"
validation_dataset_url ="./small_dataset/test.jsonl"

data_files = {
    'train': train_dataset_url,
    'test': test_dataset_url,
    'validation': validation_dataset_url
}

dataset = load_dataset('json', data_files=data_files)
train_dataset = dataset['train']
test_dataset = dataset['test']
validation_dataset = dataset['validation']

In [8]:
# Tokenize the data
def tokenize_function(examples):
    inputs = examples['input']
    targets = examples['output']
    max_length = script_args.max_seq_length
    model_input = tokenizer(inputs, max_length=2500, padding="max_length", truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=2500, padding="max_length", truncation=True)

    model_input['labels'] = labels['input_ids']
    return model_input

trained_data = train_dataset.map(tokenize_function, batched=True)
validation_data = validation_dataset.map(tokenize_function, batched=True)
test_data = test_dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/77 [00:00<?, ? examples/s]



In [10]:
accelerator = Accelerator()
model = accelerator.prepare_model(model)

In [11]:
sft_config = SFTConfig(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.gradient_accumulation_steps,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    optim=script_args.optim,
    num_train_epochs=script_args.epochs,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    eval_strategy="steps",
    eval_steps=script_args.eval_steps,
    logging_dir=script_args.logging_dir,
    warmup_ratio=script_args.warmup_ratio,
    logging_strategy="steps",
    learning_rate=script_args.learning_rate,
    max_seq_length= script_args.max_seq_length,
    fp16=script_args.fp16,
    bf16=script_args.bf16,
)

Evaluation metrics

In [12]:
from evaluate import load
import numpy as np

perplexity = load("perplexity", module_type="metric")
def compute_metrics(eval_pred):
    metrics, labels = eval_pred
    predictions = np.argmax(metrics, axis=-1)

    return perplexity.compute(predictions=predictions, model_id='gemma-2b')

In [None]:
#train
trainer = SFTTrainer(
    model=model,
    train_dataset=trained_data,
    eval_dataset=validation_data,
    peft_config=lora_config,
    tokenizer=tokenizer,
    args=sft_config,
    compute_metrics=compute_metrics,
    max_seq_length=2500,
)

trainer.train()


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
