# Fine-tuning the Gemma Base model

In [1]:
from dataclasses import dataclass, field
from typing import Optional

import torch

from transformers import AutoTokenizer, HfArgumentParser, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from accelerate import Accelerator
import os

In [None]:
from huggingface_hub import login
login()

In [2]:
@dataclass
class ScriptArguments:
    """
    Arguments for the fine_tuning
    """
    base_model = "google/gemma-2b" 
    fine_tuned_model = "gemma-2-2b-software-model_completion_finetuned"
    merged_model = "gemma-2-2b-software-model_completion"
    dataset_name = "/home/ubuntu/dataset/structural_removal_non_contiguous"
    per_device_train_batch_size: Optional[int] = field(default=1)
    per_device_eval_batch_size: Optional[int] = field(default=1)
    gradient_accumulation_steps: Optional[int] = field(default=4)
    evaluation_strategy: Optional[str] = field(default="steps")
    evaluation_accumulation_steps: Optional[int] = field(default=5)
    learning_rate: Optional[float] = field(default=2e-4)
    max_grad_norm: Optional[float] = field(default=0.3)
    weight_decay: Optional[int] = field(default=0.001)
    lora_alpha= 25,
    lora_dropout =  0.5,
    lora_r = 16
    max_seq_length: Optional[int] = field(default=4100)
    fp16 = True
    bf16 = False
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={"help": "Enables gradient checkpointing."},
    )
    use_flash_attention_2: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash Attention 2."},
    )
    optim: Optional[str] = field(
        default="paged_adamw_32bit",
        metadata={"help": "The optimizer to use."},
    )
    lr_scheduler_type: str = field(
        default="constant",
        metadata={"help": "Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis"},
    )
    max_steps: int = field(default=100, metadata={"help": "How many optimizer update steps to take"}),
    epochs : int = field(default=3, metadata={"help": "How many epochs to train for"})
    warmup_ratio: float = field(default=0.03, metadata={"help": "Fraction of steps to do a warmup for"})
    save_steps: int = field(default=87, metadata={"help": "Save checkpoint every X updates steps."})
    logging_steps: int = field(default=87, metadata={"help": "Log every X updates steps."})
    output_dir: str = field(
        default="./gemma2b/results",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    logging_dir: str = field(
        default="./gemma-2b/logs",
        metadata={"help": "The output directory where the logs will be written."},
    )
    eval_steps: int = field(default=87, metadata={"help": "How often to evaluate the model"})

parser = HfArgumentParser(ScriptArguments)
# Parse the arguments, ignoring unrecognized ones
script_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)


In [3]:
# Load the GG model - this is the local one, update it to the one on the Hub
access_token = "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"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

In [4]:
# Load model
model = AutoModelForCausalLM.from_pretrained(
    script_args.base_model,
    quantization_config=quantization_config,
    device_map ="auto",
    attn_implementation="eager"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model)
tokenizer.pad_token = tokenizer.eos_token

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [6]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16 bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

['o_proj', 'up_proj', 'q_proj', 'v_proj', 'down_proj', 'gate_proj', 'k_proj']


In [7]:
#Lora config
lora_config = LoraConfig(
    r=16,
    lora_alpha=25,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules
)

In [8]:
# Load dataset
abs_path = script_args.dataset_name
dataset_to_use = "processed_4000"
train_dataset_url = f"{abs_path}/{dataset_to_use}/train.jsonl"
test_dataset_url = f"{abs_path}/{dataset_to_use}/test.jsonl"
validation_dataset_url = f"{abs_path}/{dataset_to_use}/validation.jsonl"

data_files = {
    'train': train_dataset_url,
    'test': test_dataset_url,
    'validation': validation_dataset_url
}

dataset = load_dataset('json', data_files=data_files)
train_dataset = dataset['train']
test_dataset = dataset['test']
validation_dataset = dataset['validation']

In [10]:
# Tokenize the data
def tokenize_function(examples):
    inputs = [inp for inp in examples['input']]
    targets = examples['output']
    max_length = script_args.max_seq_length
    model_input = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True)

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_length, padding="max_length", truncation=True)

    model_input['labels'] = labels['input_ids']
    return model_input

trained_data = train_dataset.map(tokenize_function, batched=True)
validation_data = validation_dataset.map(tokenize_function, batched=True)
test_data = test_dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/580 [00:00<?, ? examples/s]

In [11]:
accelerator = Accelerator()
model = accelerator.prepare_model(model)

In [12]:
sft_config = SFTConfig(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    save_steps=script_args.save_steps,
    logging_steps=script_args.logging_steps,
    optim=script_args.optim,
    num_train_epochs=script_args.epochs,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=script_args.gradient_checkpointing,
    eval_strategy=script_args.evaluation_strategy,
    eval_steps=script_args.eval_steps,
    eval_accumulation_steps=script_args.evaluation_accumulation_steps,
    logging_dir=script_args.logging_dir,
    warmup_ratio=script_args.warmup_ratio,
    logging_strategy="steps",
    learning_rate=script_args.learning_rate,
    max_seq_length= script_args.max_seq_length,
    fp16=script_args.fp16,
    bf16=script_args.bf16,

)

In [13]:
type(trained_data)

datasets.arrow_dataset.Dataset

In [None]:
from datasets import Dataset

# Assuming `original_dataset` is your Dataset object
first_element = trained_data[0]

# Convert the first element into a new Dataset object
new_dataset_train = Dataset.from_dict({key: [value] for key, value in first_element.items()})

# Assuming `original_dataset` is your Dataset object
first_element = validation_data[0]

# Convert the first element into a new Dataset object
new_dataset_validation = Dataset.from_dict({key: [value] for key, value in first_element.items()})


In [None]:
'''from evaluate import load
import numpy as np

perplexity = load("perplexity", module_type="metric")
def compute_metrics(eval_pred):
    metrics, labels = eval_pred
    predictions = np.argmax(metrics, axis=-1)

    return perplexity.compute(predictions=predictions, model_id='gemma-2b')'''
import evaluate
import numpy as np

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)


def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    print(type(logits))
    pred_ids = torch.argmax(logits, dim=-1)

    return pred_ids, labels

Evaluation metrics

In [2]:
#train
trainer = SFTTrainer(
    model=model,
    train_dataset=trained_data,
    eval_dataset=validation_data,
    peft_config=lora_config,
    #tokenizer=tokenizer,
    args=sft_config,
    max_seq_length=script_args.max_seq_length,
    #compute_metrics=compute_metrics,
    #preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

NameError: name 'SFTTrainer' is not defined

In [1]:
trainer.train()

NameError: name 'trainer' is not defined

### Saving the Model !

In [16]:
trainer.model.save_pretrained(script_args.fine_tuned_model)

In [17]:

# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model)

base_model_reload= AutoModelForCausalLM.from_pretrained(
    script_args.base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cpu",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
from peft import PeftModel

#base_model_reload, tokenizer = setup_chat_format(base_model_reload, tokenizer)
model = PeftModel.from_pretrained(base_model_reload, script_args.fine_tuned_model)

model = model.merge_and_unload()

In [19]:
model.save_pretrained(script_args.merged_model)
tokenizer.save_pretrained(script_args.merged_model)

('gemma-2-2b-software-model_completion/tokenizer_config.json',
 'gemma-2-2b-software-model_completion/special_tokens_map.json',
 'gemma-2-2b-software-model_completion/tokenizer.json')