In [None]:
from datasets import load_dataset
import pandas as pd
import torch


In [None]:
from unsloth import FastModel

# Sequence-length setting handed to the loader below.
max_seq_length = 2048

# Load the `unsloth/gemma-3-270m-it` checkpoint; the loader returns the model
# and its tokenizer as a pair.
model, tokenizer = FastModel.from_pretrained(
    max_seq_length=max_seq_length,
    model_name="unsloth/gemma-3-270m-it",
)

In [None]:
# Configure LoRA adapters on the loaded model.
# NOTE(review): this rebinds `model` from itself, so re-running the cell passes
# the already-adapted model back in — reload the base model first for a clean re-run.
model = FastModel.get_peft_model(
    model,
    r=128,  # must be > 0; the original note suggests 8, 16, 32, 64 or 128
    lora_alpha=128,  # set equal to r here
    lora_dropout=0,  # original note: any value works, 0 is the optimized setting
    bias="none",  # original note: any value works, "none" is the optimized setting
    target_modules=[
        # attention projections
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # feed-forward projections
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    # Accepts True or "unsloth"; the original note recommends "unsloth" for
    # lower VRAM use and very long contexts.
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,  # rank-stabilized LoRA is available but left off
    loftq_config=None,  # LoftQ is available but not used
)



In [None]:
from unsloth.chat_templates import get_chat_template
# Rebind the tokenizer to the one returned by get_chat_template for "gemma3".
tokenizer = get_chat_template(tokenizer, chat_template="gemma3")

In [None]:
# Pull the "train" split of the counseling-conversations dataset.
ds = load_dataset(
    "Amod/mental_health_counseling_conversations",
    split="train",
)

In [None]:

def convert_to_chatml(text):
    """Turn one dataset row into a three-message chat conversation.

    Args:
        text: Row mapping. Its 'Context' value becomes the user turn and its
            'Response' value becomes the assistant turn.

    Returns:
        A dict with a single "conversation" key holding the system, user and
        assistant messages (in that order), each a {"role", "content"} dict.

    Raises:
        KeyError: If the row lacks 'Context' or 'Response'.
    """
    system_prompt = "You are a mental health counselor. You are given a conversation between a patient and a counselor. You need to answer the patient's question based on the conversation."
    # (role, content) pairs in the order they appear in the conversation.
    turns = (
        ("system", system_prompt),
        ("user", text['Context']),
        ("assistant", text['Response']),
    )
    return {
        "conversation": [
            {"role": speaker, "content": message} for speaker, message in turns
        ]
    }

# Run convert_to_chatml over every row of the loaded split.
dataset = ds.map(convert_to_chatml)

# Preview the first mapped row.
dataset[0]


In [None]:
def formatting_prompts_func(examples):
    """Render each conversation in a batch into a single string.

    Args:
        examples: Batched mapping whose "conversation" entry holds one message
            list per example.

    Returns:
        A dict with a "text" list containing one rendered string per
        conversation, in the same order as the input.
    """
    rendered = []
    for messages in examples["conversation"]:
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
        # Drop a leading '<bos>' if the template emitted one
        # (presumably so it is not duplicated at tokenization time — confirm).
        rendered.append(prompt.removeprefix('<bos>'))
    return {"text": rendered}

# Apply formatting_prompts_func in batched mode.
dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
)

In [None]:
# Display the "text" field of the first example as a quick check of the rendered prompt.
dataset[0]['text']

In [None]:
from trl import SFTTrainer, SFTConfig

# Build the supervised fine-tuning trainer over the "text" column.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    eval_dataset=None,  # no evaluation split supplied; one can be added here
    args=SFTConfig(
        # --- data ---
        dataset_text_field="text",
        # --- batching: 8 x 4 = 32 sequences per optimizer step on each device ---
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        # --- schedule ---
        # To train for one full epoch instead, use num_train_epochs = 1.
        max_steps=100,
        warmup_steps=10,
        lr_scheduler_type="linear",
        learning_rate=5e-5,  # original note: drop to 2e-5 for long runs
        # --- optimizer ---
        optim="adamw_8bit",
        weight_decay=0.01,
        # --- bookkeeping ---
        logging_steps=1,
        seed=42,
        output_dir="../models",
        report_to="none",  # no experiment tracker; TrackIO/WandB etc. can be named here
    ),
)

In [None]:
from unsloth.chat_templates import train_on_responses_only
# Wrap the trainer with train_on_responses_only, passing the marker strings
# that open the user turn and the model turn.
trainer = train_on_responses_only(
    trainer,
    response_part="<start_of_turn>model\n",
    instruction_part="<start_of_turn>user\n",
)

In [None]:
# Decode the token ids of training example 100 back to text for inspection.
tokenizer.decode(
    trainer.train_dataset[100]["input_ids"]
)

In [None]:
# Decode the labels of the same example. Entries equal to -100 are swapped for
# the pad token id before decoding, and pad tokens are then rendered as spaces,
# so those positions show up as blanks in the output.
tokenizer.decode(
    [
        tokenizer.pad_token_id if label_id == -100 else label_id
        for label_id in trainer.train_dataset[100]["labels"]
    ]
).replace(tokenizer.pad_token, " ")

In [None]:
# Start training; the value returned by train() is bound to trainer_stats.
trainer_stats = trainer.train()