In [1]:
import torch

from datasets import Dataset

from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
from webbot_datasets.model_prompts import bot_prompt, filter_prompt
from webbot_datasets.training_datasets import bot_dataset, bot_verification_dataset, filter_dataset, filter_verification_dataset

In [5]:
import os
from pathlib import Path
Path(os.environ['STORAGE_DIR'], "cache").mkdir(parents=True, exist_ok=True)  # HF cache dir passed to from_pretrained in init_model_and_tokenizer

def init_model_and_tokenizer(max_seq_length, model_name="/app/shared_storage/llama3.2_3b_webbot", load_in_4bit=True):
    """Load a causal LM and its tokenizer through Unsloth's FastLanguageModel.

    Args:
        max_seq_length: maximum sequence length passed to from_pretrained.
        model_name: local path or Hub id of the checkpoint to load. The default
            is the local webbot checkpoint (presumably the output of the save cell
            at the end of this notebook — confirm); pass e.g.
            "unsloth/Llama-3.2-3B-bnb-4bit" to start from the base model instead.
        load_in_4bit: load weights in 4 bit to reduce GPU memory usage.

    Returns:
        (model, tokenizer) as returned by FastLanguageModel.from_pretrained.

    Raises:
        KeyError: if the STORAGE_DIR environment variable is not set.
    """
    # Downloads are cached under $STORAGE_DIR/cache; create it here so this
    # function does not depend on an earlier cell having run.
    cache_dir = Path(os.environ['STORAGE_DIR'], "cache")
    cache_dir.mkdir(parents=True, exist_ok=True)
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=load_in_4bit,
        cache_dir=cache_dir.as_posix(),
    )
    return model, tokenizer


def create_dataset(ds, eos_token, prompt):
    """Build a `Dataset` of prompt/answer pairs for supervised fine-tuning.

    Args:
        ds: iterable of items exposing "question" and "answer" keys.
        eos_token: string appended to the end of every generated column value.
        prompt: format string with a `{query}` placeholder that receives the question.

    Returns:
        Dataset with columns, in this order:
            text      -> formatted prompt + answer + eos_token
            user      -> formatted prompt + eos_token
            assistant -> answer + eos_token
    """
    texts, users, assistants = [], [], []
    for item in ds:
        rendered = prompt.format(query=item["question"])
        answer = item["answer"]
        # NOTE(review): the prompt-only "user" column also ends in EOS; if these
        # columns are ever used for prompt/completion training, confirm that is intended.
        users.append(rendered + eos_token)
        assistants.append(answer + eos_token)
        texts.append(rendered + answer + eos_token)
    return Dataset.from_dict({"text": texts, "user": users, "assistant": assistants})


def log_gpu_stats(device=0):
    """Print the GPU name, its total memory and the current peak reserved memory.

    Intended to be called before training: the returned baseline is what
    log_gpu_usage subtracts to estimate the memory used by training.

    Args:
        device: CUDA device index to inspect (default 0).

    Returns:
        (start_gpu_memory, max_memory) in GiB, each rounded to 3 decimals:
        start_gpu_memory is torch.cuda.max_memory_reserved(device) and
        max_memory is the device's total memory.
    """
    gib = 1024 ** 3
    gpu_stats = torch.cuda.get_device_properties(device)
    # Query the same device the properties came from so both numbers describe one GPU.
    start_gpu_memory = round(torch.cuda.max_memory_reserved(device) / gib, 3)
    max_memory = round(gpu_stats.total_memory / gib, 3)
    print("-" * 80)
    print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    print(f"{start_gpu_memory} GB of memory reserved.")
    print("-" * 80)
    return start_gpu_memory, max_memory


def log_gpu_usage(start_gpu_memory, max_memory, trainer_stats):
    """Print training runtime and peak reserved GPU memory after a training run.

    Args:
        start_gpu_memory: baseline peak reserved memory in GiB (from log_gpu_stats).
        max_memory: total GPU memory in GiB (from log_gpu_stats).
        trainer_stats: result of trainer.train(); its
            metrics['train_runtime'] (seconds) is reported.
    """
    gib = 1024 ** 3
    peak_gb = round(torch.cuda.max_memory_reserved() / gib, 3)
    training_gb = round(peak_gb - start_gpu_memory, 3)
    peak_pct, training_pct = (round(value / max_memory * 100, 3) for value in (peak_gb, training_gb))

    rule = "-" * 80
    print(rule)
    runtime = trainer_stats.metrics['train_runtime']
    for line in (
        f"{runtime} seconds used for training.",
        f"{round(runtime / 60, 2)} minutes used for training.",
        f"Peak reserved memory = {peak_gb} GB.",
        f"Peak reserved memory for training = {training_gb} GB.",
        f"Peak reserved memory % of max memory = {peak_pct} %.",
        f"Peak reserved memory for training % of max memory = {training_pct} %.",
    ):
        print(line)
    print(rule)

In [None]:
# Later cells call torch.cuda.* and move tensors to the GPU, so stop here without one.
if not torch.cuda.is_available():
    # Raise rather than print: printing "Exiting." alone let Run All continue into model loading.
    raise RuntimeError("CUDA is not available; a GPU is required to load and fine-tune the model.")

max_seq_length = 1024  # also bounds the context the loaded model is configured for
model, tokenizer = init_model_and_tokenizer(max_seq_length)

In [36]:
# LoRA adapter configuration passed to Unsloth's get_peft_model.
# NOTE(review): this cell reassigns `model` to the wrapped result, so re-running it
# wraps an already-wrapped model; reload the model before running it again.
lora_settings = dict(
    r=16,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Supports any, but = 0 is optimized
    bias="none",  # Supports any, but = "none" is optimized
    use_gradient_checkpointing=True,  # True or "unsloth" for very long context
    random_state=3407,
    use_rslora=False,  # We support rank stabilized LoRA
    loftq_config=None,  # And LoftQ
)
model = FastLanguageModel.get_peft_model(model, **lora_settings)

In [None]:
# Training presets: choose one with TASK. Only the selected dataset is built,
# instead of building both and hand-editing several variables to switch.
TRAINING_PRESETS = {
    # loss should become ~ 0.1
    "filter": {"data": filter_dataset, "prompt": filter_prompt, "max_steps": 30},
    # loss should become ~ 0.012100
    "bot": {"data": bot_dataset, "prompt": bot_prompt, "max_steps": 120},
}
TASK = "bot"  # "filter" or "bot"

preset = TRAINING_PRESETS[TASK]
train_dataset = create_dataset(preset["data"], tokenizer.eos_token, preset["prompt"])
max_steps = preset["max_steps"]

# Prefer bf16 mixed precision when the GPU supports it, otherwise fp16.
use_bf16 = is_bfloat16_supported()

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    # Presumably trained on the "text" column via TRL's default dataset_text_field — confirm.
    train_dataset = train_dataset,
    packing = False, # Setting True can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,  # effective batch of 2 * 4 = 8 samples per device per step
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = max_steps,
        learning_rate = 2e-4,
        fp16 = not use_bf16,
        bf16 = use_bf16,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

# Baseline GPU memory before training; log_gpu_usage compares against it afterwards.
start_gpu_memory, max_memory = log_gpu_stats()

In [None]:
# Run fine-tuning; the returned stats object's metrics['train_runtime'] is read by log_gpu_usage.
trainer_stats = trainer.train()

In [None]:
# Report training runtime and peak reserved GPU memory relative to the pre-training baseline.
log_gpu_usage(start_gpu_memory, max_memory, trainer_stats)

In [7]:
# Put the model into Unsloth inference mode before generating.
FastLanguageModel.for_inference(model)

def query_model(prompt: str, query: str, max_new_tokens: int = 64,
                temperature: float = 0.5, min_p: float = 0.1) -> str:
    """Generate the model's answer to `query` inserted into the `prompt` template.

    Args:
        prompt: format string with a `{query}` placeholder (same templates as training).
        query: text substituted for `{query}`.
        max_new_tokens: maximum number of tokens to generate.
        temperature: forwarded to model.generate.
        min_p: forwarded to model.generate.

    Returns:
        Only the newly generated text (prompt tokens excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    # NOTE(review): add_special_tokens=False means no BOS is prepended here; confirm this
    # matches how the training "text" column was tokenized by the trainer.
    input_tokens = tokenizer(
        [prompt.format(query=query)],
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = input_tokens["input_ids"].shape[1]

    gen_ids = model.generate(
        **input_tokens,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        temperature=temperature,
        min_p=min_p,
    )

    # Decode only the continuation of the single batch row; slicing off the prompt
    # replaces the `skip_prompt` flag, which decode() does not use.
    return tokenizer.decode(gen_ids[0, prompt_len:], skip_special_tokens=True).strip()

In [None]:
def verify_filter():
    """Check the model against filter_verification_dataset using filter_prompt.

    Each case must provide "query" and "answer" keys. A case fails when the text
    returned by query_model differs from case["answer"]. Failures are printed as
    they occur, and a pass count is printed at the end.

    Returns:
        List of (query, expected_answer, model_output) tuples for failed cases
        (empty when everything matched).
    """
    failures = []
    total = 0
    for case in filter_verification_dataset:
        total += 1
        output = query_model(filter_prompt, case['query'])
        if output != case['answer']:
            print(f"FAILED: {output} - {case['query']}")
            failures.append((case['query'], case['answer'], output))
    print(f"{total - len(failures)}/{total} filter verification cases passed.")
    return failures


# NOTE(review): query_model passes temperature/min_p; if sampling is active,
# exact-match results can differ between runs.
filter_failures = verify_filter()

# if happy save :)
# model.save_pretrained(Path(os.environ['STORAGE_DIR'], "llama3.2_3b_webfilter").as_posix())
# tokenizer.save_pretrained(Path(os.environ['STORAGE_DIR'], "llama3.2_3b_webfilter").as_posix())

In [None]:
# Basic verification: print each bot answer followed by a separator for manual review.
for query in bot_verification_dataset:
    output = query_model(bot_prompt, query)
    print(output, "-" * 80, sep="\n")

# if happy save :)
# model.save_pretrained(Path(os.environ['STORAGE_DIR'], "llama3.2_3b_webbot").as_posix())
# tokenizer.save_pretrained(Path(os.environ['STORAGE_DIR'], "llama3.2_3b_webbot").as_posix())