# Finetuning

This module contains helper functions for finetuning models.

In [None]:
#| default_exp finetune

## Implementation

In [None]:
#| export
from peft import PeftModel

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# These imports are only used in test
from llama_wrapper import import_llama
from peft import LoraConfig, get_peft_model
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
import torch
import os

In [None]:
#| export
def lora_model_zeros_and_scales_to_half(
        model: PeftModel # Original model
    ) -> PeftModel: # Converted model
    """
    Cast the `scales` (and, for v1 models, the `zeros`) of 4-bit linear layers to half precision.

    A layer is picked up when its type name contains `Autograd4bitQuantLinear` or
    `Linear4bitLt`. The converted tensors are re-assigned on the layers themselves
    and the very same `model` object is returned.
    """
    quant_layer_markers = ("Autograd4bitQuantLinear", "Linear4bitLt")
    for _name, module in model.named_modules():
        type_repr = str(type(module))
        if not any(marker in type_repr for marker in quant_layer_markers):
            continue
        # `zeros` is only converted when the layer flags itself as a v1 model;
        # a layer without the flag is treated as non-v1.
        if getattr(module, "is_v1_model", False):
            module.zeros = module.zeros.half()
        module.scales = module.scales.half()
    return model

## Testing

As a *very* simple smoke test, I will:

- measure how confidently the pretrained Vicuna model predicts tokens on the sample text file (the mean top-token probability, used here as a rough stand-in for perplexity)
- train LoRA adapters on top of Vicuna
- measure the same score for the finetuned model

In [None]:
# `import_llama` hands back nine objects; only the ones used in this notebook are named,
# the rest are discarded into `_`. The Triton 4-bit autograd backend is the one enabled.
(
    _,
    train_data,
    load_llama_model_4bit_low_ram,
    _,
    model_to_half,
    _,
    apply_gradient_checkpointing,
    _,
    AMPWrapper,
) = import_llama(
    autograd_4bit_triton=True,
    autograd_4bit_cuda=False,
    use_xformers=False,
    use_flash_attention=False,
)

Using Triton implementation.


### Pretrained state

In [None]:
# Fetch the quantized Vicuna checkpoint only when it is not already present,
# so re-running the notebook does not download it again.
if not os.path.exists("../vicuna-13b-GPTQ-4bit-128g"):
    # Clone into the working directory first, then move the repo into the parent directory.
    !git clone "https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g"
    !mv "vicuna-13b-GPTQ-4bit-128g" ..

In [None]:
# Baseline: load the downloaded 4-bit checkpoint and its tokenizer, untouched by finetuning.
model_pretrained, tokenizer = load_llama_model_4bit_low_ram(
    model_path="../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors",
    config_path="../vicuna-13b-GPTQ-4bit-128g/",
    is_v1_model=False,
    groupsize=128,
)
# Use token id 0 as the padding token.
tokenizer.pad_token_id = 0

Loading Model ...


The safetensors archive passed at ../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.


Loaded the model in 3.27 seconds.


In [None]:
# Build the training set from the sample text file with the tokenizer loaded above.
# NOTE(review): `val_set_size=0` presumably means no validation split is held out — confirm.
dataset = train_data.TrainTxt(
    tokenizer=tokenizer,
    dataset="01_alpaca_text.txt",
    cutoff_len=256,
    val_set_size=0,
)
dataset.prepare_data(use_eos_token=1, thd=-1)

                                                  

Train Data: 0.00% outliers


In [None]:
# Convert the baseline model to half precision (the call reports "Converted as Half.").
model_to_half(model_pretrained)
# Wrap the model and patch its forward pass through the AMP wrapper.
# NOTE(review): presumably this runs forward under autocast — confirm against AMPWrapper.
model_pretrained_wrapper = AMPWrapper(model_pretrained)
model_pretrained_wrapper.apply_forward()

Converted as Half.


In [None]:
def _test_model(model, data):
    probabilities = []
    with torch.no_grad():
        for sample in data:
            input_ids = torch.LongTensor([sample["input_ids"]]).cuda()
            response = model.forward(input_ids, return_dict=True)
            logits = response['logits'][0]
            probas = torch.nn.functional.softmax(logits, dim=-1)
            proba = probas.max(dim=-1).values.mean().item()
            probabilities.append(proba)
    average_proba = sum(probabilities) / len(probabilities)
    return average_proba

In [None]:
# Put the baseline model in evaluation mode, then score it on the training samples.
# The value displayed below is the reference to compare the finetuned model against.
model_pretrained.eval()
_test_model(model_pretrained, dataset.train_data)

0.703210663377193

In [None]:
# Move the baseline model off the GPU and drop cached CUDA allocations
# before a second copy of the checkpoint is loaded for finetuning.
model_pretrained.cpu()
torch.cuda.empty_cache()

### Finetune

In [None]:
# Load a fresh copy of the same checkpoint; this one receives the LoRA adapters.
model, tokenizer = load_llama_model_4bit_low_ram(
    model_path="../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors",
    config_path="../vicuna-13b-GPTQ-4bit-128g/",
    is_v1_model=False,
    groupsize=128,
)
# Use token id 0 as the padding token.
tokenizer.pad_token_id = 0

Loading Model ...


The safetensors archive passed at ../vicuna-13b-GPTQ-4bit-128g/vicuna-13b-4bit-128g.safetensors does not contain metadata. Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata.


Loaded the model in 2.38 seconds.


In [None]:
# Rank-8 LoRA adapters on the query and value projections only.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,
    bias="none",
)
# Wrap the base model with the adapters, then cast the quantization scales/zeros to half.
lora_model = lora_model_zeros_and_scales_to_half(get_peft_model(model, lora_config))

In [None]:
# Patch the model's blocks for gradient checkpointing (one "Forward Patch Applied" line
# is printed per block). The trailing `;` keeps the return value from being displayed.
# NOTE(review): `checkpoint_ratio=1` presumably selects every block — confirm.
apply_gradient_checkpointing(lora_model, checkpoint_ratio=1);

Forward Patch Applied For Block 0
Forward Patch Applied For Block 1
Forward Patch Applied For Block 2
Forward Patch Applied For Block 3
Forward Patch Applied For Block 4
Forward Patch Applied For Block 5
Forward Patch Applied For Block 6
Forward Patch Applied For Block 7
Forward Patch Applied For Block 8
Forward Patch Applied For Block 9
Forward Patch Applied For Block 10
Forward Patch Applied For Block 11
Forward Patch Applied For Block 12
Forward Patch Applied For Block 13
Forward Patch Applied For Block 14
Forward Patch Applied For Block 15
Forward Patch Applied For Block 16
Forward Patch Applied For Block 17
Forward Patch Applied For Block 18
Forward Patch Applied For Block 19
Forward Patch Applied For Block 20
Forward Patch Applied For Block 21
Forward Patch Applied For Block 22
Forward Patch Applied For Block 23
Forward Patch Applied For Block 24
Forward Patch Applied For Block 25
Forward Patch Applied For Block 26
Forward Patch Applied For Block 27
Forward Patch Applied For Bloc

In [None]:
training_arguments = TrainingArguments(
    # Checkpointing: where and how often adapters are saved
    output_dir="lora-output-directory",
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    load_best_model_at_end=False,
    # Optimisation schedule
    optim="adamw_torch",
    learning_rate=3e-4,
    warmup_steps=5,
    num_train_epochs=10,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    fp16=True,
    # No evaluation pass during training
    evaluation_strategy="no",
    eval_steps=None,
    # Logging and distributed settings
    logging_steps=20,
    report_to="none",
    ddp_find_unused_parameters=False,
)

In [None]:
# Causal-LM collation (`mlm=False`) over the prepared training samples.
trainer = Trainer(
    model=lora_model,
    args=training_arguments,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    train_dataset=dataset.train_data,
    eval_dataset=dataset.val_data,
)
# Turn off `use_cache` on the model config before training starts.
# NOTE(review): presumably needed because caching conflicts with gradient checkpointing — confirm.
lora_model.config.use_cache = False

In [None]:
# Run the finetuning loop; the returned TrainOutput is displayed as the cell result.
trainer.train()

                                                

{'loss': 3.5212, 'learning_rate': 0.0002930973451327433, 'epoch': 0.35}


                                                

{'loss': 2.2487, 'learning_rate': 0.0002824778761061947, 'epoch': 0.7}


                                                  

{'loss': 2.211, 'learning_rate': 0.00027185840707964596, 'epoch': 1.05}


                                                

{'loss': 2.0809, 'learning_rate': 0.00026123893805309734, 'epoch': 1.4}


                                                 

{'loss': 1.7463, 'learning_rate': 0.00025061946902654866, 'epoch': 1.75}


                                                   

{'loss': 1.8963, 'learning_rate': 0.00023999999999999998, 'epoch': 2.11}


                                                 

{'loss': 1.2548, 'learning_rate': 0.0002293805309734513, 'epoch': 2.46}


                                                   

{'loss': 1.8632, 'learning_rate': 0.00021929203539823008, 'epoch': 2.81}


                                                 

{'loss': 1.3105, 'learning_rate': 0.00020867256637168138, 'epoch': 3.16}


                                                 

{'loss': 1.1854, 'learning_rate': 0.00019805309734513272, 'epoch': 3.51}


                                                 

{'loss': 1.3491, 'learning_rate': 0.00018849557522123892, 'epoch': 3.86}


                                                 

{'loss': 1.6193, 'learning_rate': 0.00017787610619469026, 'epoch': 4.21}


                                                   

{'loss': 1.0474, 'learning_rate': 0.00016831858407079646, 'epoch': 4.56}


                                                 

{'loss': 1.2461, 'learning_rate': 0.00015769911504424775, 'epoch': 4.91}


                                                 

{'loss': 1.2838, 'learning_rate': 0.00014761061946902654, 'epoch': 5.26}


                                                   

{'loss': 1.0848, 'learning_rate': 0.00013699115044247788, 'epoch': 5.61}


                                                 

{'loss': 1.2285, 'learning_rate': 0.0001263716814159292, 'epoch': 5.96}


                                                   

{'loss': 1.0472, 'learning_rate': 0.00011575221238938052, 'epoch': 6.32}


                                                 

{'loss': 0.8704, 'learning_rate': 0.00010513274336283186, 'epoch': 6.67}


                                                 

{'loss': 1.475, 'learning_rate': 9.451327433628319e-05, 'epoch': 7.02}


                                                   

{'loss': 0.9795, 'learning_rate': 8.389380530973451e-05, 'epoch': 7.37}


                                                 

{'loss': 1.7936, 'learning_rate': 7.327433628318583e-05, 'epoch': 7.72}


                                                   

{'loss': 1.4427, 'learning_rate': 6.265486725663716e-05, 'epoch': 8.07}


                                                 

{'loss': 0.9338, 'learning_rate': 5.203539823008849e-05, 'epoch': 8.42}


                                                 

{'loss': 1.7045, 'learning_rate': 4.141592920353982e-05, 'epoch': 8.77}


                                                 

{'loss': 1.519, 'learning_rate': 3.0796460176991146e-05, 'epoch': 9.12}


                                                 

{'loss': 1.3084, 'learning_rate': 2.0176991150442476e-05, 'epoch': 9.47}


                                                 

{'loss': 1.5564, 'learning_rate': 9.557522123893805e-06, 'epoch': 9.82}


                                                 

{'train_runtime': 1549.3327, 'train_samples_per_second': 0.368, 'train_steps_per_second': 0.368, 'train_loss': 1.52357904032657, 'epoch': 10.0}


100%|██████████| 570/570 [25:49<00:00,  2.72s/it]


TrainOutput(global_step=570, training_loss=1.52357904032657, metrics={'train_runtime': 1549.3327, 'train_samples_per_second': 0.368, 'train_steps_per_second': 0.368, 'train_loss': 1.52357904032657, 'epoch': 10.0})

### Tuned model

In [None]:
# Put the finetuned model in evaluation mode and score it with the same helper and data
# as the pretrained baseline, so the two displayed values are directly comparable.
lora_model.eval()
_test_model(lora_model, dataset.train_data)



0.8225835499010588