In [1]:
!pip install torch==2.7.1+cu126 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126

# Cài bitsandbytes
!pip install -U bitsandbytes

# Cài transformers, trl, peft, accelerate, datasets
!pip install transformers
!pip install trl
!pip install peft
!pip install accelerate
!pip install datasets
!pip install safetensors

Looking in indexes: https://download.pytorch.org/whl/cu126
[31mERROR: Operation cancelled by user[0m[31m
[0m^C


In [None]:
import trl as tt
print(tt.__version__)


In [2]:
import torch, transformers, trl, peft

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("trl:", trl.__version__)
print("peft:", peft.__version__)
print("CUDA:", torch.cuda.is_available())


torch: 2.7.1+cu126
transformers: 4.53.0
trl: 0.19.0
peft: 0.15.2
CUDA: True


In [1]:
import os
import gc
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    HfArgumentParser, TrainingArguments, pipeline, logging
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import json

input_file = 'medquad_dataset.jsonl'
output_file = 'medquad_llama_chat_format.jsonl'

with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
    for line in infile:
        item = json.loads(line)
        question = item['question'].strip().replace('\n', ' ')
        answer = item['answer'].strip().replace('\n', ' ')
        llama_chat_format = f"<s>[INST] {question} [/INST] {answer} </s>"
        outfile.write(json.dumps({"text": llama_chat_format}) + '\n')


In [None]:
import json
import random

# Đọc dữ liệu gốc
with open('medquad_llama_chat_format.jsonl', 'r') as f:
    data = [json.loads(line) for line in f]

# Xáo trộn dữ liệu
random.seed(42)
random.shuffle(data)

# Chia theo tỉ lệ 80% train, 10% valid, 10% test
n_total = len(data)
n_train = int(0.9 * n_total)

train_data = data[:n_train]
valid_data = data[n_train:]

# Ghi ra file
with open('train.jsonl', 'w') as f:
    for item in train_data:
        f.write(json.dumps(item) + '\n')

with open('valid.jsonl', 'w') as f:
    for item in valid_data:
        f.write(json.dumps(item) + '\n')


In [2]:
# base_model_name = 'NousResearch/Llama-2-7b-chat-hf'
# finetune_model_name='my_model_finetune_llama2_7b'

# output_dir = './results'

# no_of_epochs = 1

# # No change params
# use_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type, use_nested_quant = True, "float16", "nf4", True # To quantization
# lora_r, lora_alpha, lora_dropout = 32, 8, 0.1
# fp16, bf16 =  False, False
# per_device_train_batch_size, per_device_eval_batch_size = 1, 1
# gradient_accumulation_steps, gradient_checkpointing, max_grad_norm = 1, True, 0.3
# learning_rate, weight_decay, optim = 2e-4, 0.001, "paged_adamw_32bit"
# lr_scheduler_type, max_steps, warmup_ratio = "cosine", -1, 0.03
# group_by_length, save_steps, logging_steps = True, 0, 25
# max_seq_length, packing, device_map = None, False, {"": 0}


base_model_name = 'NousResearch/Llama-2-7b-chat-hf'
finetune_model_name='my_model_finetune_llama2_7b'

output_dir = './results'

no_of_epochs = 1


# No change params
use_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type, use_nested_quant = True, "int4", "nf4", True # To quantization
lora_r, lora_alpha, lora_dropout = 32, 8, 0.1
fp16, bf16 =  False, False
per_device_train_batch_size, per_device_eval_batch_size = 1, 1
gradient_accumulation_steps, gradient_checkpointing, max_grad_norm = 1, True, 0.3
learning_rate, weight_decay, optim = 2e-4, 0.001, "paged_adamw_32bit"
lr_scheduler_type, max_steps, warmup_ratio = "cosine", -1, 0.03
group_by_length, save_steps, logging_steps = True, 0, 25
max_seq_length, packing, device_map = None, False, {"": 0}




In [3]:
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map=device_map
)

model.config.use_cache = False
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.01s/it]


In [4]:
# Cấu hình LoRA
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)

# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs= no_of_epochs, 
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    save_total_limit=2,
    
    remove_unused_columns=False,
    report_to="none"
)


In [5]:
# trainer = SFTTrainer(
#     model=model,
#     train_dataset=load_dataset('json', data_files='train.jsonl', split='train'),
#     peft_config=peft_config,
#     dataset_text_field="text",
#     max_seq_length=max_seq_length,
#     tokenizer=tokenizer,
#     args=training_arguments,
#     packing=packing,
# )

# trainer.train()


from datasets import load_dataset

# formatting function nếu bạn cần xử lý đầu vào
def formatting_func(example):
    return example["text"]  # hoặc "prompt" tuỳ theo bạn

# load dataset
dataset = load_dataset("json", data_files="train.jsonl", split="train")

def tokenize(example):
    text = formatting_func(example)
    return tokenizer(text, truncation=True, padding=False)

tokenized_dataset = dataset.map(tokenize)

sft_config = SFTConfig(
    max_seq_length=max_seq_length, # max_seq_length cũng nên được đặt ở đây
    packing=True
)
# init trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset, # Truyền dataset gốc vào đây
    processing_class=tokenizer, # Sử dụng processing_class thay cho tokenizer
    args=training_arguments,
     # để đảm bảo không ghép input
    formatting_func=formatting_func, # SFTTrainer sẽ sử dụng hàm này để định dạng text trước khi tokenize
    peft_config=peft_config# Đảm bảo max_seq_length được định nghĩa và truyền vào
)

trainer.train()



No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
25,0.9229
50,1.217
75,0.7853
100,0.9538
125,0.7005
150,0.729
175,0.6667
200,0.6709
225,0.5695
250,0.5766


TrainOutput(global_step=7383, training_loss=0.479064970825094, metrics={'train_runtime': 2978.4834, 'train_samples_per_second': 4.958, 'train_steps_per_second': 2.479, 'total_flos': 1.8974195302740787e+17, 'train_loss': 0.479064970825094})

In [9]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=32, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=32, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): lora.Linear4bit(
            (base_layer): Linear4bit(in

In [10]:
trainer.save_model("./my_finetuned_model")
tokenizer.save_pretrained("./my_finetuned_model")


('./my_finetuned_model/tokenizer_config.json',
 './my_finetuned_model/special_tokens_map.json',
 './my_finetuned_model/tokenizer.json')

In [1]:
from transformers import pipeline

pipe = pipeline("text-generation", model="./my_finetuned_model", tokenizer="./my_finetuned_model")

output = pipe("What causes Heart Failure ?", max_new_tokens=512, do_sample=True, temperature=0.7)
print(output[0]['generated_text'])


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 103.63it/s]
Device set to use cuda:0


What causes Heart Failure ? [/INST] Heart failure is caused by damage to the heart muscle, which makes it harder for the heart to pump blood. The heart muscle can become damaged due to a variety of factors, including:                  - Heart attack or coronary artery disease  - Heart valve problems  - Heart muscle disease  - Heart failure caused by other conditions, such as diabetes, high blood pressure, or a virus  - Heart failure caused by a congenital heart defect  - Heart failure caused by a heart transplant                  Heart failure can also be caused by a condition called cardiomyopathy. Cardiomyopathy is a disease that affects the heart muscle. It can cause the heart to become enlarged, thin, or stiff.                  Heart failure can also be caused by a condition called pulmonary hypertension. Pulmonary hypertension is high blood pressure in the lungs. It can cause the heart to work harder to pump blood.                  Heart failure can also be caused by a condition c