# Notebook for fine-tuning Mistral

## Load model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training


# Hugging Face Hub id of the base checkpoint that gets fine-tuned.
base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# bitsandbytes settings: 4-bit NF4 weights with double quantization,
# computing in bfloat16.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

# Load the base model already quantized according to bnb_config.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
)

# Turn on gradient checkpointing, then let PEFT prepare the quantized
# model for k-bit training.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
# Tokenizer for the base checkpoint: sequences are limited to 512 tokens,
# padding goes on the left, and an EOS token is appended to every encoding.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_eos_token=True,
    padding_side="left",
    model_max_length=512,
)

# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

## Data preparation

In [None]:
## create datasets
# from .dataset_utils import get_dataset, get_instruct_dataset
# dataset = concatenate_datasets([
#     get_dataset(),
#     get_instruct_dataset()
# ])

## load datasets 
from datasets import load_from_disk, concatenate_datasets
 

# On-disk datasets that together form the training corpus.
dataset_paths = (
    'data/llms-405417/simple_text_med_dataset.hf',
    'data/llms-405417/instruct_med_dataset.hf',
)

# Load each part, join them into one dataset, and shuffle the rows.
dataset = concatenate_datasets(
    [load_from_disk(path) for path in dataset_paths]
).shuffle()

In [None]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Walks ``model.named_parameters()``, totals the element counts, and
    prints the trainable count, the overall count, and the trainable share
    as a percentage. Returns nothing.
    """
    # (element count, is trainable) for every parameter, collected in one pass.
    sizes = [
        (tensor.numel(), tensor.requires_grad)
        for _name, tensor in model.named_parameters()
    ]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    print(
        f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total}"
    )

In [None]:
from peft import LoraConfig, get_peft_model

# Modules that receive LoRA adapters: the attention projections, the MLP
# projections, and the output head.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
    "lm_head",
]

# LoRA hyperparameters for causal language modelling.
config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=64,
    lora_dropout=0.05,  # Conventional
    bias="none",
    target_modules=lora_target_modules,
)

# Wrap the model with the adapters and report how much of it is trainable.
model = get_peft_model(model, config)
print_trainable_parameters(model)


In [None]:
# When more than one GPU is visible, flag the model as parallelizable and
# model-parallel.
# NOTE(review): presumably this stops the trainer from wrapping the model in
# DataParallel — confirm against the transformers Trainer documentation.
if torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True

## Training

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments


# Directory that receives the trainer's outputs.
output_dir = "./med_mistral"

# Train with the same sequence limit the tokenizer was configured with.
max_seq_length = tokenizer.model_max_length

# Reuse the EOS token as the padding token (repeated here so this cell also
# works when run on its own).
tokenizer.pad_token = tokenizer.eos_token

# Prefer bf16 mixed precision when the GPU supports it, otherwise fp16.
use_bf16 = torch.cuda.is_bf16_supported()

# Supervised fine-tuning on the "text" column of the dataset, without packing.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        # 2 samples per device x 4 accumulation steps per optimizer update.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        warmup_steps=5,
        max_steps=60,
        learning_rate=2e-5,
        fp16=not use_bf16,
        bf16=use_bf16,
        # With max_steps=60 this logs only once, at step 50.
        logging_steps=50,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=output_dir,
        report_to='none',  # or log to Weights & Biases ("wandb")
        logging_dir="./logs",
    ),
)


model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

## Model inference

In [None]:
# Instruction template: the user message goes between [INST] and [/INST].
prompt = "[INST] {} [/INST]"

# Switch the model to inference: re-enable the cache that was disabled for
# training and put the module in eval mode.
model.config.use_cache = True
model.eval()

# NOTE(review): the tokenizer was created with add_eos_token=True, so the
# encoded prompt presumably ends with an EOS token — confirm this is what you
# want for generation.
inputs = tokenizer(
    [
        prompt.format(
            "Continue the fibonnaci sequence. 1, 1, 2, 3, 5, 8 ..."
        )
        # prompt.format(
        #     "What are symptoms of flu?"
        # )
    ],
    return_tensors="pt",
).to("cuda")

# Generate up to 64 new tokens and decode the whole batch back to text.
outputs = model.generate(**inputs, max_new_tokens=64, use_cache=True)
tokenizer.batch_decode(outputs)

## Push model to HF hub

In [None]:
# Upload the model and tokenizer to the Hub repository "atadria/med_mistral",
# requesting the "merged_16bit" save method. The commented-out line is the
# local-save variant of the same call.
# NOTE(review): `push_to_hub_merged` / `save_pretrained_merged` are not defined
# or imported anywhere in this notebook; `model` here comes from
# `peft.get_peft_model` — confirm it actually exposes these methods.
# NOTE(review): `token` is an empty string; supply a Hub write token at run
# time (e.g. from an environment variable) and never commit it.
# model.save_pretrained_merged("med_mistral", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("atadria/med_mistral", tokenizer, save_method = "merged_16bit", token = "")