In [None]:
import torch
from datasets import load_dataset, concatenate_datasets
from peft import get_peft_model, LoraConfig
from peft import prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments

In [None]:
from datasets import load_dataset

#Prepare dataset
def format_prompts(examples):
    """
    Placeholder for the dataset prompt formatter.

    Intended contract: take a batch of examples and return a dictionary
    with a "text" key containing the formatted prompts.

    Not implemented yet -- the function currently ignores `examples` and
    returns None.
    """
    # TODO: build the prompt strings from `examples` and return {"text": [...]}.
    return None

# Load the training data as a Hugging Face Dataset so that .map() is available.
# (json.load() expects an open file object, not a path, and returns plain
# Python data without a .map() method; `json` was also never imported.)
# NOTE(review): assumes the file is a JSON array of records or JSON Lines; if
# the records are nested under a key, pass field="<key>" -- confirm.
dataset = load_dataset("json", data_files="easy_train_data.json", split="train")

# Apply the prompt formatter batch-wise.
dataset = dataset.map(format_prompts, batched=True)

In [None]:
# Set up the model and tokenizer
model_name = "mistralai/Mistral-7B-v0.1"

# Quantization settings: 4-bit NF4 weights with double quantization,
# bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map="auto" places the quantized weights on the available device(s)
# at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prepare the quantized model for k-bit (LoRA) training.
model = prepare_model_for_kbit_training(model)

# Do NOT call model.to(device) here: the model is already placed by
# device_map="auto", and Transformers versions that do not support .to() on
# 4-bit/8-bit bitsandbytes models raise ValueError. `device` is kept only as
# a convenience name for later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(model.device)

In [None]:
# Set up PEFT (LoRA adapters)
peft_config = LoraConfig(
    # Fixed typo: was "CASUAL_LM", which is not a valid PEFT task type.
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=32,
    lora_alpha=64,
    # Modules that receive LoRA adapters.
    # TODO: look into which target modules are available for this model.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    lora_dropout=0.1,
)

# Wrap the base model with the LoRA adapters and report how many parameters
# will actually be trained.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
#Set up the training arguments
from transformers import TrainingArguments

# Training hyperparameters; checkpoints are written to ./mistral_Lora.
args = TrainingArguments(
    output_dir="mistral_Lora",
    num_train_epochs=4,
    per_device_train_batch_size=16,
    # Fixed typo: was `learniing_rate`, which TrainingArguments rejects with
    # a TypeError (unexpected keyword argument).
    learning_rate=1e-5,
    # NOTE(review): plain SGD with lr=1e-5 may train very slowly for LoRA;
    # kept as in the original -- confirm this is intentional.
    optim="sgd",
)

In [None]:
from trl import SFTTrainer

# Gather the SFTTrainer settings in one place so they can be inspected
# before the trainer is constructed.
sft_options = dict(
    model=model,
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=1024,
)
trainer = SFTTrainer(**sft_options)

# Start fine-tuning.
trainer.train()