In [None]:
!pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes

In [None]:
!pip install --no-deps xformers trl peft accelerate bitsandbytes


In [None]:
import json

import torch
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel

In [None]:
# Log in to the Hugging Face Hub from inside the notebook. The push_to_hub and
# push_to_hub_merged calls later in this notebook presumably rely on this login.
notebook_login()

In [None]:
# bf16 needs a CUDA device that supports it. Evaluate the check once so the model
# dtype and the trainer precision flags below can never disagree (loading the
# model in float16 while training with bf16=True is an inconsistent setup).
# The is_available() guard keeps this cell runnable before a GPU is attached.
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Single source of truth for the whole run: model, LoRA, dataset and trainer settings.
config = {
    "hugging_face_username": "Apurva3509",
    "model_config": {
        "base_model": "unsloth/llama-3-8b-Instruct-bnb-4bit",
        "finetuned_model": "llama-3-8b-Instruct-bnb-4bit-medical",
        "max_seq_length": 2048,
        # Must match the fp16/bf16 flags in "training_config".
        "dtype": torch.bfloat16 if bf16_supported else torch.float16,
        "load_in_4bit": True,
    },
    "lora_config": {
        "r": 16,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                           "gate_proj", "up_proj", "down_proj"],
        "lora_alpha": 16,
        "lora_dropout": 0,
        "bias": "none",
        "use_gradient_checkpointing": True,
        "use_rslora": False,
        "use_dora": False,
        "loftq_config": None,
    },
    "training_dataset": {
        "name": "Shekswess/medical_llama3_instruct_dataset_short",
        "split": "train",
        # Dataset column handed to the trainer as the text field.
        "input_field": "prompt",
    },
    "training_config": {
        "per_device_train_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "warmup_steps": 5,
        # Not positive, so no step cap: num_train_epochs decides the run length
        # (see the TrainingArguments documentation for max_steps).
        "max_steps": 0,
        "num_train_epochs": 1,
        "learning_rate": 2e-4,
        # Exactly one of fp16 / bf16 is enabled, mirroring "dtype" above.
        "fp16": not bf16_supported,
        "bf16": bf16_supported,
        "logging_steps": 1,
        "optim": "adamw_8bit",
        "weight_decay": 0.01,
        "lr_scheduler_type": "linear",
        "seed": 42,
        "output_dir": "outputs",
    },
}

In [None]:
# Load the base model and its tokenizer with the settings from config["model_config"].
base_model_cfg = config.get("model_config")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_cfg.get("base_model"),
    max_seq_length=base_model_cfg.get("max_seq_length"),
    dtype=base_model_cfg.get("dtype"),
    load_in_4bit=base_model_cfg.get("load_in_4bit"),
)

# Wrap the loaded model with LoRA adapters. Every keyword below is named exactly
# like its key in config["lora_config"], so the values are forwarded by name.
# NOTE: this rebinds `model`; re-running the cell would wrap the already-wrapped model.
lora_cfg = config.get("lora_config")
lora_option_names = (
    "r",
    "target_modules",
    "lora_alpha",
    "lora_dropout",
    "bias",
    "use_gradient_checkpointing",
    "use_rslora",
    "use_dora",
    "loftq_config",
)
model = FastLanguageModel.get_peft_model(
    model,
    random_state=42,
    **{option: lora_cfg.get(option) for option in lora_option_names},
)

# Fetch the training split named in config["training_dataset"].
dataset_cfg = config.get("training_dataset")
dataset_train = load_dataset(
    dataset_cfg.get("name"),
    split=dataset_cfg.get("split"),
)

# Build the supervised fine-tuning trainer. All TrainingArguments values,
# including the seed, come from config["training_config"] so that editing the
# config cell is the only thing needed to change a run.
# (A stray 40-character hex string that followed this statement was removed: it was
# not valid Python and looked like a pasted credential. If it was a real API key or
# token, revoke it.)
trainer_args_cfg = config.get("training_config")

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset_train,
    dataset_text_field = config.get("training_dataset").get("input_field"),
    max_seq_length = config.get("model_config").get("max_seq_length"),
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = trainer_args_cfg.get("per_device_train_batch_size"),
        gradient_accumulation_steps = trainer_args_cfg.get("gradient_accumulation_steps"),
        warmup_steps = trainer_args_cfg.get("warmup_steps"),
        max_steps = trainer_args_cfg.get("max_steps"),
        num_train_epochs = trainer_args_cfg.get("num_train_epochs"),
        learning_rate = trainer_args_cfg.get("learning_rate"),
        fp16 = trainer_args_cfg.get("fp16"),
        bf16 = trainer_args_cfg.get("bf16"),
        logging_steps = trainer_args_cfg.get("logging_steps"),
        optim = trainer_args_cfg.get("optim"),
        weight_decay = trainer_args_cfg.get("weight_decay"),
        lr_scheduler_type = trainer_args_cfg.get("lr_scheduler_type"),
        # Previously hard-coded to 42, which silently ignored the config value.
        seed = trainer_args_cfg.get("seed"),
        output_dir = trainer_args_cfg.get("output_dir"),
    ),
)

In [None]:
# Capture the GPU memory baseline *before* training. The memory report cell that
# follows training reads `reserved_memory` and `max_memory`; without these
# assignments a fresh Restart & Run All fails there with a NameError.
gpu_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
reserved_memory = round(torch.cuda.max_memory_reserved() / 1024**3, 2)  # GB held before training
max_memory = round(gpu_properties.total_memory / 1024**3, 2)            # total device memory in GB
print(f"Reserved Memory before training: {reserved_memory}GB of {max_memory}GB")

# Run the fine-tuning loop; the returned statistics are written to disk in a later cell.
trainer_stats = trainer.train()

In [None]:
# Report peak GPU memory after training, in GB and as a share of the device total.
# `reserved_memory` and `max_memory` must already exist in the kernel (captured
# before training) when this cell runs.
# NOTE(review): the peak here is max_memory_allocated(), while the baseline name
# suggests a "reserved" figure — confirm both sides use the same metric.
def _percent_of_total(gigabytes):
    """Return `gigabytes` as a percentage of `max_memory`, rounded to 2 decimals."""
    return round((gigabytes / max_memory) * 100, 2)


used_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 2)
used_memory_lora = round(used_memory - reserved_memory, 2)
used_memory_persentage = _percent_of_total(used_memory)
used_memory_lora_persentage = _percent_of_total(used_memory_lora)

print(f"Used Memory: {used_memory}GB ({used_memory_persentage}%)")
print(f"Used Memory for training(fine-tuning) LoRA: {used_memory_lora}GB ({used_memory_lora_persentage}%)")

In [None]:
# Persist the object returned by trainer.train() as indented JSON.
# NOTE(review): if that object is a tuple-like result, json writes it as an array
# without field names — confirm this is the desired on-disk shape.
stats_path = "trainer_stats.json"
with open(stats_path, "w") as stats_file:
    json.dump(trainer_stats, stats_file, indent=4)

In [None]:
# Save and publish the LoRA adapter under its own "-lora" name.
# The merged-model cell below writes to config["model_config"]["finetuned_model"];
# using that same directory / Hub repo for the adapter would leave adapter files
# and full-model files side by side, which makes the later reload ambiguous.
adapter_name = config.get("model_config").get("finetuned_model") + "-lora"

# The tokenizer is saved and pushed explicitly: `tokenizer=` is not a documented
# parameter of push_to_hub, so passing it there does not upload the tokenizer.
model.save_pretrained(adapter_name)
tokenizer.save_pretrained(adapter_name)
model.push_to_hub(adapter_name)
tokenizer.push_to_hub(adapter_name)

In [None]:
# Write the merged model plus tokenizer locally, then push the same artifact to the
# Hub, both under the name in config["model_config"]["finetuned_model"].
merged_model_name = config.get("model_config").get("finetuned_model")
merge_method = "merged_4bit_forced"

model.save_pretrained_merged(merged_model_name, tokenizer, save_method=merge_method)
model.push_to_hub_merged(merged_model_name, tokenizer, save_method=merge_method)

In [None]:
# Reload the fine-tuned model by name (this rebinds `model` and `tokenizer`),
# switch it to inference mode and generate an answer for one sample question.
inference_cfg = config.get("model_config")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=inference_cfg.get("finetuned_model"),
    max_seq_length=inference_cfg.get("max_seq_length"),
    dtype=inference_cfg.get("dtype"),
    load_in_4bit=inference_cfg.get("load_in_4bit"),
)
FastLanguageModel.for_inference(model)

# NOTE(review): the prompt stops after the user turn and carries no assistant
# header — confirm this matches the format of the training prompts.
sample_prompt = "<|start_header_id|>system<|end_header_id|> Answer the question truthfully, you are a medical professional.<|eot_id|><|start_header_id|>user<|end_header_id|> This is the question: Can you provide an overview of the lung's squamous cell carcinoma?<|eot_id|>"

inputs = tokenizer([sample_prompt], return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)

# Last expression of the cell, so the decoded text is displayed as its output.
tokenizer.batch_decode(outputs, skip_special_tokens=True)