Here is where we fine-tune the model. We use the `transformers` library to load the pre-trained model and tokenizer.

This is configured to run on a local machine with a single GPU — in this case, an NVIDIA RTX 3090.

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import bitsandbytes
import accelerate
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
)
from peft import get_peft_model, LoraConfig



In [None]:
# Quick GPU sanity check: show which CUDA device this kernel sees
# (CUDA_VISIBLE_DEVICES is pinned to "0" at the top of the notebook).
for probe in (
    torch.cuda.current_device,
    lambda: torch.cuda.device(0),
    torch.cuda.device_count,
    lambda: torch.cuda.get_device_name(0),
):
    print(probe())

In [None]:
# Every tokenized example is truncated/padded to exactly this many tokens.
max_input_length = 1024

# Caps on how many examples of each split are used.
max_train_examples = 5000
max_eval_examples = 500

# Load the dataset from a local directory.
# NOTE(review): assumes load_from_disk returns a single Dataset (not a
# DatasetDict), since train_test_split is called on it directly.
dataset = load_from_disk("MathInstructSmall")

# Split the dataset 50/50 into training and testing sets. The fixed seed makes
# the split (and therefore the subsets selected below) reproducible across runs.
dataset = dataset.train_test_split(test_size=0.5, seed=42)

# Keep at most the capped number of rows; min() keeps select() from asking for
# indices past the end when a split is smaller than the cap.
train_dataset = dataset["train"].select(range(min(max_train_examples, len(dataset["train"]))))
test_dataset = dataset["test"].select(range(min(max_eval_examples, len(dataset["test"]))))

# Load the tokenizer. No new token is added: EOS is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained("failspy/Meta-Llama-3-8B-Instruct-abliterated-v3")
tokenizer.pad_token = tokenizer.eos_token
 
def tokenize(prompt=None, add_eos_token=True, max_length=None):
    """Tokenize ``prompt`` into fixed-length causal-LM training features.

    The text is truncated so that, when ``add_eos_token`` is true, an EOS token
    always fits at the end of the real tokens. The sequence is then padded to
    exactly ``max_length`` with the tokenizer's pad token, on the tokenizer's
    padding side.

    Args:
        prompt: Text to tokenize.
        add_eos_token: Append ``tokenizer.eos_token_id`` unless the truncated
            sequence already ends with it.
        max_length: Length of every returned sequence (required).

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` lists, each
        exactly ``max_length`` long. Padding positions have attention mask 0
        and label -100, so they do not contribute to the training loss.

    Raises:
        ValueError: if ``max_length`` is None or too small to hold any token
            (plus the EOS when ``add_eos_token`` is true).
    """
    min_length = 2 if add_eos_token else 1
    if max_length is None or max_length < min_length:
        raise ValueError(f"max_length must be an int >= {min_length}, got {max_length!r}")

    # Reserve one slot for EOS so truncation can never cut it off. Padding is
    # done by hand below so the EOS lands before the padding. (Padding to
    # max_length inside the tokenizer call always fills the sequence, which
    # leaves no room to append EOS afterwards.)
    text_budget = max_length - 1 if add_eos_token else max_length
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=text_budget,
        padding=False,
        return_tensors=None,
    )
    input_ids = list(encoded["input_ids"])
    attention_mask = list(encoded["attention_mask"])

    if add_eos_token and (not input_ids or input_ids[-1] != tokenizer.eos_token_id):
        input_ids.append(tokenizer.eos_token_id)
        attention_mask.append(1)

    # Labels mirror the real tokens. The pad token is the same id as EOS, so
    # padding must be marked -100 (the ignored label value); otherwise the
    # model would be trained to emit EOS at every padding position.
    labels = list(input_ids)

    pad_count = max_length - len(input_ids)
    id_padding = [tokenizer.pad_token_id] * pad_count
    mask_padding = [0] * pad_count
    label_padding = [-100] * pad_count
    if tokenizer.padding_side == "left":
        input_ids = id_padding + input_ids
        attention_mask = mask_padding + attention_mask
        labels = label_padding + labels
    else:
        input_ids += id_padding
        attention_mask += mask_padding
        labels += label_padding

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
 
def preprocess_function(entry):
    """Turn one dataset row into training features.

    Tokenizes the row's ``text`` field with ``tokenize``, using
    ``max_input_length`` as the target length and requesting a trailing EOS.
    """
    return tokenize(
        prompt=entry['text'],
        max_length=max_input_length,
        add_eos_token=True,
    )

# Tokenize both splits. remove_columns drops the raw columns (e.g. "text") so
# each row keeps only the features returned by preprocess_function.
tokenized_train_dataset = train_dataset.map(
    preprocess_function, remove_columns=train_dataset.column_names
)
tokenized_test_dataset = test_dataset.map(
    preprocess_function, remove_columns=test_dataset.column_names
)


# Load the base model with 8-bit weights. A BitsAndBytesConfig passed as
# quantization_config replaces the deprecated `load_in_8bit=True` keyword.
model = LlamaForCausalLM.from_pretrained(
    "failspy/Meta-Llama-3-8B-Instruct-abliterated-v3",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)

# No token was added to the tokenizer (pad reuses EOS), so the embedding matrix
# normally already covers the vocabulary. Only grow it when needed: resizing to
# a smaller size would drop embedding rows from the end.
if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
    model.resize_token_embeddings(len(tokenizer))

def create_peft_config(model):
    """Prepare a quantized model for training and wrap it with LoRA adapters.

    Args:
        model: Base causal-LM model, presumably loaded in 8-bit — confirm
            against the caller.

    Returns:
        ``(peft_model, lora_config)``: the model returned by
        ``get_peft_model`` and the ``LoraConfig`` used to build it.
    """
    from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

    # Ready the int-8 model for training before the adapters are injected.
    prepared_model = prepare_model_for_kbit_training(model)

    # Rank-8 LoRA on the `q_proj` / `v_proj` modules only.
    lora_settings = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    wrapped_model = get_peft_model(prepared_model, lora_settings)
    return wrapped_model, lora_settings

# Wrap the base model with LoRA adapters (returns the PEFT model and its config).
model, lora_config = create_peft_config(model)


# Check, before building the Trainer, that every tokenized row has exactly
# max_input_length entries in each column of both splits. The Trainer below is
# built without a data_collator, so it presumably uses the default collator,
# which stacks rows into tensors and cannot handle ragged lengths. An explicit
# ValueError is used instead of `assert` so the check survives `python -O`.
for split_name, split in (("train", tokenized_train_dataset), ("test", tokenized_test_dataset)):
    for column in ("input_ids", "attention_mask", "labels"):
        bad_rows = [row for row, seq in enumerate(split[column]) if len(seq) != max_input_length]
        if bad_rows:
            raise ValueError(
                f"{split_name} split: {len(bad_rows)} rows of '{column}' are not "
                f"{max_input_length} long (first offending row: {bad_rows[0]})"
            )


# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # evaluate on the test split once per epoch
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    per_device_train_batch_size=1,
    fp16=True,  # mixed precision training
    gradient_checkpointing=True,  # recompute activations to save memory
    gradient_accumulation_steps=16,  # effective batch = 1 x 16 per device
    save_total_limit=3,
    save_steps=100,
    # NOTE(review): eval_steps only takes effect with eval_strategy="steps";
    # with "epoch" above it is ignored. Kept for easy switching.
    eval_steps=100,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_test_dataset,
)

# Train the model
trainer.train()

# Save the fine-tuned model and tokenizer side by side. `model` is the object
# returned by get_peft_model, so this presumably writes the LoRA adapter weights
# rather than a merged full model. Confirm this before loading the output standalone.
model.save_pretrained("./Meta-Llama-3-8B-Instruct-abliterated-math-v0")
tokenizer.save_pretrained("./Meta-Llama-3-8B-Instruct-abliterated-math-v0")