In [None]:
# Step 1: install the released Unsloth package together with its dependencies.
!pip install unsloth
# Step 2: remove that release and install the `nightly` branch of unsloth plus unsloth-zoo
# straight from GitHub. `--no-deps` leaves the dependencies installed in step 1 untouched.
# NOTE(review): neither install is version-pinned, so a fresh run may pull different code —
# consider pinning a commit/tag if reproducibility matters.
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git@nightly git+https://github.com/unslothai/unsloth-zoo.git

In [None]:
import os
import torch
import pandas as pd
from unsloth import FastLanguageModel
from datasets import Dataset, load_dataset, concatenate_datasets
from trl import SFTTrainer
from transformers import TrainingArguments
from huggingface_hub import login

from google.colab import userdata
# Read the Hugging Face token from the Colab secret store (never hard-code it)
# and authenticate this session with it.
HF_TOKEN = userdata.get("HF_TOKEN")
login(token=HF_TOKEN)


In [None]:
# Checkpoint to fine-tune, and the name used both as the local save directory
# and as the Hub repository for the result.
model_name = "unsloth/Qwen3-4B-unsloth-bnb-4bit"
repo_name = "VyDat/qwen3-4b-bnb-4bit"

# Dataset identifiers handed to `load_dataset`.
data_path = "VyDat/vsl-data"
data_augment_path = "VyDat/Copus-Vie-VSL-10K"
pub_data_path = "5CD-AI/Vietnamese-Multi-turn-Chat-Alpaca"

# Sequence-length cap shared by model loading and the trainer.
max_seq_length = 2048
# Forwarded as-is to `FastLanguageModel.from_pretrained`.
dtype = None

In [None]:
# Load the base checkpoint and its tokenizer with 4-bit weights.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,  # None in the config cell; presumably lets Unsloth choose — TODO confirm
    load_in_4bit=True,
)

# Attach LoRA adapters. `model` is rebound to the adapter-wrapped model, so every
# later cell (trainer, merge) operates on the wrapped object.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,  # LoRA rank
    # Layers that receive adapters (by name these look like the attention and MLP
    # projections of the architecture — verify against the model definition).
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=64,  # twice the rank
    lora_dropout=0.05,
    bias="none",
    use_gradient_checkpointing=True,
    random_state=42,
)

In [None]:
def format_data_hf(data_path, split="train", num_proc=4):
    """Load a chat dataset and render every conversation to a single text string.

    Each row's ``messages`` value is passed to ``tokenizer.apply_chat_template``
    (the module-level ``tokenizer``) with ``tokenize=False``; the result is stored
    in a ``text`` column and all original columns are dropped. Rows whose
    ``messages`` value is missing (``None``), empty, or has fewer than two entries
    are skipped with a printed warning, so the output can be shorter than the input.

    Args:
        data_path: Dataset identifier forwarded to ``load_dataset``.
        split: Split to load. Defaults to ``"train"``.
        num_proc: Number of worker processes used by ``Dataset.map``. Defaults to 4.

    Returns:
        The mapped dataset, containing only the ``text`` column.

    Raises:
        ValueError: If the loaded dataset has no ``messages`` column.
        Exception: Any error raised by ``load_dataset`` is printed and re-raised.
    """
    print(f"Đang tải dataset từ: {data_path}")
    try:
        raw_dataset = load_dataset(data_path, split=split)
    except Exception as e:
        print(f"Lỗi khi tải dataset từ Hub ({data_path}): {e}")
        raise

    # Check the schema up front so a wrong dataset fails with a clear message.
    if "messages" not in raw_dataset.column_names:
        raise ValueError("Cột 'messages' không tồn tại trong dataset")

    def formatting_function(examples):
        """Render one batch of conversations; invalid rows are left out."""
        texts = []
        for messages in examples["messages"]:
            # `not messages` catches a null (None) row — on which len() would raise
            # TypeError and abort the whole map — as well as an empty list.
            if not messages or len(messages) < 2:
                print(f"Warning! Tin nhắn không hợp lệ: {messages}")
                continue

            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            texts.append(text)

        return {"text": texts}

    # Dropping every input column is what allows the batched function to return
    # fewer rows than it received.
    formatted_dataset = raw_dataset.map(
        formatting_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
        desc="Applying chat template (no system instruction)",
        num_proc=num_proc
    )

    print(f"Thành công! Số lượng mẫu: {len(formatted_dataset)}")
    return formatted_dataset


# Render the two project datasets to text. `format_data_hf` raises ValueError
# if either one lacks a `messages` column.
vsl_dataset = format_data_hf(data_path)
data_augment = format_data_hf(data_augment_path)


In [None]:
# Sanity check: print the first rendered conversation of each dataset to confirm
# the chat template produced the expected layout.
print(vsl_dataset[0]["text"])
print(data_augment[0]["text"])

In [None]:
def normalize_messages(messages):
    """Convert ``from``/``value`` turns into ``role``/``content`` turns.

    A turn whose ``from`` is ``"human"`` becomes role ``"user"`` and ``"gpt"``
    becomes ``"assistant"``. Turns with any other ``from`` value are dropped.
    The order of the remaining turns is preserved.

    Args:
        messages: Iterable of mappings with ``from`` and ``value`` keys.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.
    """
    role_of = {"human": "user", "gpt": "assistant"}
    # The filter runs before the dict is built, so `value` is only read for turns
    # that are actually kept.
    return [
        {"role": role_of[turn["from"]], "content": turn["value"]}
        for turn in messages
        if turn["from"] in role_of
    ]


def format_public_ds(data_path, split="train", num_proc=4):
    """Load a ``conversations``-style dataset and render each one to text.

    Each row's ``conversations`` value is converted with ``normalize_messages``
    and then passed to ``tokenizer.apply_chat_template`` (the module-level
    ``tokenizer``) with ``tokenize=False``. The result is stored in a ``text``
    column and all original columns are dropped. A row is silently skipped when
    its ``conversations`` value is missing (``None``), empty, or has fewer than
    two turns either before or after normalization.

    Args:
        data_path: Dataset identifier forwarded to ``load_dataset``.
        split: Split to load. Defaults to ``"train"``.
        num_proc: Number of worker processes used by ``Dataset.map``. Defaults to 4.

    Returns:
        The mapped dataset, containing only the ``text`` column.

    Raises:
        ValueError: If the loaded dataset has no ``conversations`` column.
        Exception: Any error raised by ``load_dataset`` is printed and re-raised.
    """
    print(f"Đang tải dataset từ: {data_path}")
    try:
        raw_dataset = load_dataset(data_path, split=split)
    except Exception as e:
        print(f"Lỗi khi tải dataset từ Hub ({data_path}): {e}")
        raise

    if "conversations" not in raw_dataset.column_names:
        raise ValueError("Cột 'conversations' không tồn tại trong dataset")

    def formatting_function(examples):
        """Render one batch of conversations; unusable rows are left out."""
        texts = []

        for conversation in examples["conversations"]:
            # `not conversation` catches a null (None) row — on which len() would
            # raise TypeError and abort the whole map — as well as an empty list.
            if not conversation or len(conversation) < 2:
                continue

            messages = normalize_messages(conversation)

            # Normalization drops turns with unknown speakers, so re-check the length.
            if len(messages) < 2:
                continue

            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False
            )
            texts.append(text)

        return {"text": texts}

    # Dropping every input column is what allows the batched function to return
    # fewer rows than it received.
    formatted_dataset = raw_dataset.map(
        formatting_function,
        batched=True,
        remove_columns=raw_dataset.column_names,
        desc="Applying chat template (alpaca → HF)",
        num_proc=num_proc
    )

    print(f"Thành công! Số lượng mẫu: {len(formatted_dataset)}")
    return formatted_dataset

# Render the public multi-turn dataset and print its first sample as a sanity check.
public_data = format_public_ds(pub_data_path)
print(public_data[0]["text"])


In [None]:
# Merge the three rendered corpora into one dataset and report its total size.
combined_data = concatenate_datasets([vsl_dataset, data_augment, public_data])
print(len(combined_data))

# Hold out 10% of the rows for evaluation; the fixed seed makes the split repeatable.
split_dataset = combined_data.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]

print(f"Số mẫu train (90%): {len(train_dataset)}")
print(f"Số mẫu validation (10%): {len(eval_dataset)}")

In [None]:
# Supervised fine-tuning setup: trains on the pre-rendered "text" column of
# `train_dataset` and evaluates on `eval_dataset`.
# NOTE(review): `tokenizer`, `dataset_text_field`, `max_seq_length`, `dataset_num_proc`
# and `packing` are passed directly to SFTTrainer together with a plain
# TrainingArguments — confirm the installed TRL/Unsloth versions still accept this form.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=4,
    packing=True,
    args=TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step per device
        warmup_steps=100,
        num_train_epochs=2,
        learning_rate=3e-4,
        # Exactly one of fp16/bf16 is enabled, depending on whether the GPU reports bf16 support.
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="outputs",
        # Checkpointing and evaluation run on the same 200-step cadence.
        save_strategy="steps",
        save_steps=200,
        eval_strategy="steps",
        eval_steps=200,
        save_total_limit=3,
        report_to="none",
    ),
)

In [None]:
print("Start Trainning.....")
# Release cached GPU memory that PyTorch is holding but not using before the run starts.
torch.cuda.empty_cache()

trainer_stats = trainer.train()

# `training_loss` on the returned train output is the loss averaged over the run,
# not a sum, so the message labels it as an average.
print(f"Training successful, average training loss: {trainer_stats.training_loss}")

In [None]:
print("Merging LoRA adapters into base model...")
# `model` is rebound to the merged result from here on.
# NOTE(review): re-running this cell alone would call merge_and_unload on the
# already-merged object — confirm that is supported, or bind the result to a new name.
# NOTE(review): the base model was loaded with load_in_4bit=True; merging adapters into
# 4-bit weights may lose precision. Consider Unsloth's merged-save helpers
# (save_pretrained_merged / push_to_hub_merged) — confirm which output format is intended.
model = model.merge_and_unload()

# `repo_name` doubles as the local output directory, so the files land under ./<repo_name>.
model.save_pretrained(repo_name)
tokenizer.save_pretrained(repo_name)

# Upload is best-effort: a failure is reported but not re-raised, and the local copy
# written above is still available.
try:
    model.push_to_hub(repo_name, use_temp_dir=True)
    tokenizer.push_to_hub(repo_name, use_temp_dir=True)
    print(f"Training hoàn thành! Model đã được đẩy lên repo: https://huggingface.co/{repo_name}")
except Exception as e:
    print(f"Đã xảy ra lỗi khi đẩy model lên Hub: {e}")
    print(f"Model đã được lưu local tại './{repo_name}'")