Here is where we do the fine tuning of the model. We will use the `transformers` library to load the pre-trained model and tokenizer. 

This is configured to run on a loacl machine with a GPU. In this case a single NVIDIA RTX 3090. 

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import bitsandbytes
import accelerate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, AutoModelForCausalLM, LlamaForCausalLM
from peft import get_peft_model, LoraConfig



In [2]:
print(torch.cuda.current_device())
print(torch.cuda.device(0))
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))

0
<torch.cuda.device object at 0x77074ca46950>
1
NVIDIA GeForce RTX 3090


In [3]:
max_input_length = 2048
# Load the dataset
dataset = load_dataset("patrickjmcbride/math-instruct-dataset")

# Trim the dataset to 2000 entries for training and 200 entries for testing
train_dataset = dataset['train'].select(range(2000))
test_dataset = dataset['test'].select(range(200))

# Load the tokenizer and add a special pad token
tokenizer = AutoTokenizer.from_pretrained("failspy/Meta-Llama-3-8B-Instruct-abliterated-v3")
tokenizer.pad_token = tokenizer.eos_token

def generate_prompt(entry):
    return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
{entry["instruction"]}
### Context:
{entry["context"]}
### Response:
{entry["output"]}"""
 
def tokenize(prompt, add_eos_token=True, max_length=2048):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=max_length,
        padding='max_length',  # Add padding
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < max_length
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
 
    result["labels"] = result["input_ids"].copy()
 
    return result
 
def preprocess_function(entry):
    full_prompt = generate_prompt(entry)
    tokenized_full_prompt = tokenize(prompt=full_prompt, max_length=max_input_length)
    return tokenized_full_prompt

# Apply the preprocessing function and filter the dataset
tokenized_train_dataset = train_dataset.map(preprocess_function)
tokenized_test_dataset = test_dataset.map(preprocess_function)


# Load the model and resize embeddings for the new special token
model = LlamaForCausalLM.from_pretrained("failspy/Meta-Llama-3-8B-Instruct-abliterated-v3", load_in_8bit=True, torch_dtype=torch.float16)
model.resize_token_embeddings(len(tokenizer))

def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_kbit_training,
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules = ["q_proj", "v_proj"]
    )

    # prepare int-8 model for training
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)
    return model, peft_config

# create peft config
model, lora_config = create_peft_config(model)


# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    per_device_train_batch_size=1,
    fp16=True,  # Enable mixed precision training
    gradient_checkpointing=True,  # Enable gradient checkpointing
    gradient_accumulation_steps=16,  # Accumulate gradients over 16 steps
    save_total_limit=3,
    save_steps=100,
    eval_steps=100,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
)

# Print the shapes of the input IDs and attention mask
for i, data in enumerate(tokenized_train_dataset):
    # make sure they are all padded to max_input_length
    assert len(data["input_ids"]) == max_input_length
    assert len(data["attention_mask"]) == max_input_length

# Train the model
trainer.train()

# Save the fine-tuned model
model.save_pretrained("./Meta-Llama-3-8B-Instruct-abliterated-math-v0")
tokenizer.save_pretrained("./Meta-Llama-3-8B-Instruct-abliterated-math-v0")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])
Input IDs shape: torch.Size([2048])
Attention Mask shape: torch.Size([2048])

  0%|          | 0/375 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


{'loss': 4.7726, 'grad_norm': 14.51529598236084, 'learning_rate': 1.9573333333333335e-05, 'epoch': 0.08}
{'loss': 3.8359, 'grad_norm': 10.308284759521484, 'learning_rate': 1.904e-05, 'epoch': 0.16}
{'loss': 2.701, 'grad_norm': 17.57379150390625, 'learning_rate': 1.8560000000000002e-05, 'epoch': 0.24}
{'loss': 1.5154, 'grad_norm': 6.312963485717773, 'learning_rate': 1.8026666666666668e-05, 'epoch': 0.32}
{'loss': 1.0415, 'grad_norm': 0.7911285161972046, 'learning_rate': 1.7493333333333334e-05, 'epoch': 0.4}
{'loss': 1.0092, 'grad_norm': 0.46872249245643616, 'learning_rate': 1.696e-05, 'epoch': 0.48}
{'loss': 0.9861, 'grad_norm': 0.3881501257419586, 'learning_rate': 1.642666666666667e-05, 'epoch': 0.56}
{'loss': 0.9624, 'grad_norm': 0.3677951991558075, 'learning_rate': 1.5893333333333333e-05, 'epoch': 0.64}
