In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [2]:
%%capture
# Install latest transformers for Gemma 3N
!pip install --no-deps --upgrade transformers # Only for Gemma 3N
!pip install --no-deps --upgrade timm # Only for Gemma 3N

In [3]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E2B-it", # Or "unsloth/gemma-3n-E4B-it"
    dtype = None, # None for auto detection
    max_seq_length = 4096, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.6.0+cu124 with CUDA 1204 (you have 2.8.0+cu126)
    Python  3.12.9 (you have 3.12.11)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.9: Fast Gemma3N patching. Transformers: 4.55.4.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/2.65G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/469M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/98.0 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

preprocessor_config.json: 0.00B [00:00, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.70M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

In [4]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_nlp_modules       = True,  # Should leave on always!

    r = 16,           # Larger = higher accuracy, but might overfit
    lora_alpha = 16,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients


In [7]:
# Load a structured chat prompt format into your tokenizer
from unsloth.chat_templates import get_chat_template as attach_template

def format_tokenizer_for_chat(tok, template="gemma-3"):
    """Adds a conversational structure to the tokenizer for chat-style input."""
    return attach_template(tok, chat_template=template)

tokenizer = format_tokenizer_for_chat(tokenizer)

In [11]:

from datasets import load_dataset, Dataset
import random, json

def preview_dataset(name, subset, max_rows=None):
    """Utility to load and preview a HuggingFace dataset subset."""
    data = load_dataset(name, split=subset)
    if max_rows:
        data = data.select(range(max_rows))
    print(f" Loaded '{name}' with {len(data)} entries.")
    return data

# Psychology dataset → distinguishing empathy vs. dismissive responses
psych_data = preview_dataset("jkhedri/psychology-dataset", "train", max_rows=3000)

# Coaching dataset → distinguishing questioning vs. directive leadership styles
coach_data = preview_dataset("drublackberry/hbr-coaching-real-leaders", "train")


 Loaded 'jkhedri/psychology-dataset' with 3000 entries.


Resolving data files:   0%|          | 0/40 [00:00<?, ?it/s]

 Loaded 'drublackberry/hbr-coaching-real-leaders' with 40 entries.


In [14]:
# %%
# Parsing datasets and preparing for DPO conversion
from datasets import Dataset
import gc

print("Step 1: Parsing both datasets...")

# Function to process the psychology dataset
def process_psychology_data(psych_source, max_samples=3000):
    """Process psychology dataset and prepare for DPO"""

    print(f"Psychology dataset: {len(psych_source)} total rows, using {max_samples}")

    parsed_samples = []

    # Coaching prompts to enrich psychology responses
    coaching_prompts = [
        "What feels most important to you as you think about this?",
        "What would you like to focus on moving forward?",
        "How would you like to approach this situation?",
        "What support would be most helpful for you right now?",
        "What small step could you take today?",
        "What matters most to you in this situation?",
        "How do you want to move forward with this?",
        "What would success look like for you here?"
    ]

    for i in range(min(max_samples, len(psych_source))):
        if i % 500 == 0:
            print(f"Processing psychology: {i}/{max_samples}")
            gc.collect()

        row = psych_source[i]

        # Add enrichment to good response
        prompt_addition = coaching_prompts[i % len(coaching_prompts)]
        enriched_response = f"{row['response_j']}\n\n{prompt_addition}"

        parsed_samples.append({
            "prompt": row['question'],
            "chosen": enriched_response,
            "rejected": row['response_k']
        })

    return parsed_samples

# Function to process coaching dataset
def process_coaching_data(coach_source):
    """Extract useful coaching pairs from dataset"""

    print(f"Coaching dataset: {len(coach_source)} rows")

    coaching_pairs = []

    for idx, row in enumerate(coach_source):
        messages = row['messages']

        for i in range(len(messages) - 1):
            current = messages[i]
            next_msg = messages[i + 1]

            # Ensure valid user → assistant exchange
            if (
                current.get('role') == 'user'
                and next_msg.get('role') == 'assistant'
                and len(current.get('content', '')) > 50
                and len(next_msg.get('content', '')) > 30
                and '?' in next_msg.get('content', '')
            ):
                client_text = current['content'].strip()
                coach_text = next_msg['content'].strip()

                # Artificial "bad coaching" response
                poor_response = f"You should just {client_text.lower().split()[0]} differently. Stop overthinking and take action immediately."

                coaching_pairs.append({
                    "prompt": client_text,
                    "chosen": coach_text,
                    "rejected": poor_response
                })

    print(f"Extracted {len(coaching_pairs)} coaching pairs")
    return coaching_pairs

# Run parsing
psychology_samples = process_psychology_data(psych_data, max_samples=3000)
coaching_samples = process_coaching_data(coach_data)

# Combine
all_samples = psychology_samples + coaching_samples
print(f"Total parsed examples: {len(all_samples)}")

# Step 2: Convert to DPO format
print("Step 2: Converting to DPO format...")

def convert_to_dpo(samples):
    """Convert parsed data into DPO-compatible format"""

    formatted = []

    for i, ex in enumerate(samples):
        if i % 500 == 0:
            print(f"Converting: {i}/{len(samples)}")
            gc.collect()

        formatted.append({
            "prompt": ex['prompt'],
            "chosen": ex['chosen'],
            "rejected": ex['rejected']
        })

    return formatted

# Convert and finalize
dpo_data = convert_to_dpo(all_samples)
final_dataset = Dataset.from_list(dpo_data)

# Cleanup memory
del psych_data, coach_data, psychology_samples, coaching_samples, all_samples, dpo_data
gc.collect()

print(f"Final DPO dataset ready: {len(final_dataset)} examples")


Step 1: Parsing both datasets...
Psychology dataset: 3000 total rows, using 3000
Processing psychology: 0/3000
Processing psychology: 500/3000
Processing psychology: 1000/3000
Processing psychology: 1500/3000
Processing psychology: 2000/3000
Processing psychology: 2500/3000
Coaching dataset: 40 rows
Extracted 789 coaching pairs
Total parsed examples: 3789
Step 2: Converting to DPO format...
Converting: 0/3789
Converting: 500/3789
Converting: 1000/3789
Converting: 1500/3789
Converting: 2000/3789
Converting: 2500/3789
Converting: 3000/3789
Converting: 3500/3789
Final DPO dataset ready: 3789 examples


In [27]:
# data_utils.py
import os
import gc
from datasets import load_dataset, Dataset


def load_and_validate_datasets():
    """
    Load psychology and coaching datasets, then clean and validate them
    for ultra-safe training.

    Returns:
        ultra_safe_dataset: HuggingFace Dataset, cleaned and ready for training
    """
    # Force single-threaded processing
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    print("[INFO] Loading psychology dataset...")
    psych_dataset = load_dataset("jkhedri/psychology-dataset", split="train[:3000]")

    print("[INFO] Loading executive coaching dataset...")
    coaching_dataset = load_dataset("drublackberry/hbr-coaching-real-leaders", split="train")

    print(f"[INFO] Psychology dataset: {len(psych_dataset)} examples")
    print(f"[INFO] Coaching dataset: {len(coaching_dataset)} examples")

    # Merge datasets if needed (example: append coaching to psychology)
    combined_dataset = psych_dataset + coaching_dataset

    print("[INFO] Validating and cleaning combined dataset...")

    clean_examples = []
    for i, example in enumerate(combined_dataset):
        if (example and isinstance(example, dict) and
            example.get("prompt") and example.get("chosen") and example.get("rejected") and
            isinstance(example["prompt"], str) and isinstance(example["chosen"], str) and isinstance(example["rejected"], str) and
            len(example["prompt"].strip()) > 5 and len(example["chosen"].strip()) > 5 and len(example["rejected"].strip()) > 5
           ):
            clean_examples.append({
                "prompt": example["prompt"].strip(),
                "chosen": example["chosen"].strip(),
                "rejected": example["rejected"].strip()
            })
        else:
            print(f"[WARN] Skipping invalid example {i}")

    ultra_safe_dataset = Dataset.from_list(clean_examples)

    # Clear memory
    del combined_dataset, clean_examples, psych_dataset, coaching_dataset
    gc.collect()

    print(f"[INFO] Ultra-safe dataset ready: {len(ultra_safe_dataset)} examples")
    return ultra_safe_dataset


In [32]:
from datasets import Dataset

# Example: create a Dataset from a list of dictionaries
# Replace this with your real dataset
my_data_list = [
    {"prompt": "Hello, how are you?", "chosen": "I'm good, thanks!"},
    {"prompt": "What's AI?", "chosen": "AI is artificial intelligence."},
]

# Convert to Hugging Face Dataset
cleaned = Dataset.from_list(my_data_list)


In [29]:
# Ultra-safe DPO training: single-threaded, stable configuration
from trl import DPOTrainer, DPOConfig

print("Starting ultra-safe DPO training...")

try:
    trainer = DPOTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ultra_safe_dataset,  # Use cleaned dataset
        args=DPOConfig(
            output_dir="./wellness-coach-safe",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            max_steps=100,
            learning_rate=1e-6,
            warmup_steps=10,
            logging_steps=10,
            save_steps=25,
            optim="adamw_8bit",
            gradient_checkpointing=True,
            fp16=True,

            # Single-threaded for stability
            dataloader_num_workers=0,
            preprocessing_num_workers=1,
            dataloader_persistent_workers=False,
            dataloader_pin_memory=False,

            # DPO-specific parameters
            beta=0.1,
            loss_type="sigmoid",

            # Other safety settings
            seed=3407,
            report_to="none",
            disable_tqdm=False,
            remove_unused_columns=True,
        ),
    )

    print("Trainer configured successfully.")
    print(f"Training on {len(ultra_safe_dataset)} examples.")

    # Start training
    stats = trainer.train()

    print("Training completed successfully.")
    print(f"Training runtime: {round(stats.metrics['train_runtime']/60, 2)} minutes.")

    # Save trained model and tokenizer
    model.save_pretrained("wellness-coach-trained")
    tokenizer.save_pretrained("wellness-coach-trained")
    print("Model saved successfully.")

except Exception as e:
    print(f"DPO training failed: {e}")
    print("Attempting fallback SFT training...")


Starting ultra-safe DPO training...
DPO training failed: name 'ultra_safe_dataset' is not defined
Attempting fallback SFT training...


In [33]:
# %%
# ULTRA-SAFE TRAINING PIPELINE: DPO first, SFT fallback
from trl import DPOTrainer, DPOConfig, SFTTrainer, SFTConfig
from datasets import Dataset
from unsloth.chat_templates import standardize_data_formats

# -------------------------------
# STEP 0: Define your dataset
# Replace with your preprocessed dataset variable if different
try:
    data = cleaned  # Your preprocessed dataset variable
except NameError:
    raise NameError("Dataset variable 'cleaned' not found. Assign your preprocessed dataset to 'cleaned'.")

# -------------------------------
# STEP 1: ULTRA-SAFE DPO TRAINING
print("Starting ultra-safe DPO training...")

try:
    dpo_trainer = DPOTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=data,
        args=DPOConfig(
            output_dir="./wellness-coach-safe",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=1,
            max_steps=100,
            learning_rate=1e-6,
            warmup_steps=10,
            logging_steps=10,
            save_steps=25,
            optim="adamw_8bit",
            gradient_checkpointing=True,
            fp16=True,
            dataloader_num_workers=0,
            preprocessing_num_workers=1,
            dataloader_persistent_workers=False,
            dataloader_pin_memory=False,
            beta=0.1,
            loss_type="sigmoid",
            seed=3407,
            report_to="none",
            disable_tqdm=False,
            remove_unused_columns=True,
        ),
    )

    print("Ultra-safe DPO trainer configured.")
    print(f"Training {len(data)} examples.")

    trainer_stats = dpo_trainer.train()

    print("DPO training completed successfully.")
    print(f"Training time: {round(trainer_stats.metrics['train_runtime']/60, 2)} minutes")

    model.save_pretrained("wellness-coach-trained")
    tokenizer.save_pretrained("wellness-coach-trained")
    print("Model saved successfully.")

except Exception as e:
    print(f"DPO training failed: {e}")
    print("Attempting fallback SFT training...")

    # -------------------------------
    # STEP 2: SFT FALLBACK TRAINING
    def convert_dpo_to_sft(dpo_dataset):
        sft_examples = []
        for example in dpo_dataset:
            conversation = [
                {"role": "user", "content": example['prompt']},
                {"role": "assistant", "content": example['chosen']}
            ]
            sft_examples.append({"conversations": conversation})
        return sft_examples

    sft_data = convert_dpo_to_sft(data)
    sft_dataset = Dataset.from_list(sft_data)

    sft_dataset = standardize_data_formats(sft_dataset)

    def formatting_prompts_func(examples):
        convos = examples["conversations"]
        texts = [
            tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False).removeprefix('<bos>')
            for convo in convos
        ]
        return {"text": texts}

    sft_dataset = sft_dataset.map(formatting_prompts_func, batched=True)

    sft_trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=sft_dataset,
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=1,
            gradient_accumulation_steps=2,
            warmup_steps=10,
            max_steps=100,
            learning_rate=2e-6,
            logging_steps=10,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
            output_dir="./wellness-coach-sft",
            dataloader_num_workers=0,
            remove_unused_columns=True,
            gradient_checkpointing=True,
            fp16=True,
        ),
    )

    print("SFT trainer configured as fallback.")
    print("SFT will train on good responses only.")

    # Uncomment to start fallback training
    # sft_trainer.train()


Starting ultra-safe DPO training...
DPO training failed: DPOConfig.__init__() got an unexpected keyword argument 'preprocessing_num_workers'
Attempting fallback SFT training...


Unsloth: Standardizing formats (num_proc=2):   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/2 [00:00<?, ? examples/s]

SFT trainer configured as fallback.
SFT will train on good responses only.


In [34]:
"""### Train the Enhanced Wellness Coach"""

print(" Starting Enhanced Wellness Coach DPO Training...")
print("Learning: Empathy vs Dismissive + Questions vs Directive advice")
print("  This will take some time on T4...")

trainer_stats = sft_trainer.train()

 Starting Enhanced Wellness Coach DPO Training...
Learning: Empathy vs Dismissive + Questions vs Directive advice
  This will take some time on T4...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 2 | Num Epochs = 100 | Total steps = 100
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 2 x 1) = 2
 "-____-"     Trainable parameters = 21,135,360 of 5,460,573,632 (0.39% trained)


Step,Training Loss
10,15.2644
20,14.1238
30,1.5469
40,0.0521
50,0.0325
60,0.0323
70,0.0323
80,0.0323
90,0.0323
100,0.0323


In [35]:
# %%
# TEST: Check if the model actually learned anything
def test_wellness_coach():
    """Test the trained wellness coach model on example prompts"""

    test_prompts = [
        "I've been feeling really anxious about my job performance lately.",
        "I can't seem to stick to my exercise routine.",
        "I feel overwhelmed with everything in my life right now.",
        "I'm torn between a safe career move and a risky opportunity."
    ]

    print("Testing trained wellness coach model...")

    for idx, prompt in enumerate(test_prompts):
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": prompt}]
        }]

        # Prepare model inputs
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            tokenize=True,
            return_dict=True,
        ).to("cuda")

        print(f"\n{'='*60}")
        print(f"Test {idx+1}: {prompt}")
        print("Response: ", end="")

        from transformers import TextStreamer

        _ = model.generate(
            **inputs,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=50,
            streamer=TextStreamer(tokenizer, skip_prompt=True),
        )

# Execute the test
test_wellness_coach()


Testing trained wellness coach model...

Test 1: I've been feeling really anxious about my job performance lately.
Response: It's completely understandable to feel anxious about your job performance. That's a very common experience, and it's brave of you to acknowledge it.  Let's break this down.  I want to help you explore this feeling.  

First, **it's okay to feel anxious.**  It doesn't mean you're failing or doing something wrong.  It's a natural human response to pressure and uncertainty.

To help me understand what's going

Test 2: I can't seem to stick to my exercise routine.
Response: It's incredibly common to struggle with sticking to an exercise routine! It's a really challenging goal, and it's great that you're recognizing it.  Let's break down why this happens and explore some strategies to help you get back on track.  I'll cover common reasons for difficulty, practical tips, and mindset shifts.



**1. Understanding Why You're Struggling**

Before we jump into solutions, l

In [37]:
# %%
# ENHANCED MULTI-PROMPT TEST: Creative & consoling responses
def test_multiple_creative_responses():
    """Test multiple prompts and store enhanced responses"""

    test_prompts = [
        "I'm feeling burned out at work and it's affecting my relationships.",
        "I feel anxious and unsure about my career path.",
        "Lately, I can't find motivation to maintain my daily routines.",
        "I feel overwhelmed trying to balance personal and professional life."
    ]

    from transformers import TextStreamer

    results = []  # Store prompts and responses

    for i, prompt in enumerate(test_prompts):
        messages = [{
            "role": "user",
            "content": [{"type": "text", "text": prompt}]
        }]

        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            tokenize=True,
            return_dict=True,
        ).to("cuda")

        print(f"\n{'='*60}\nTest {i+1}: {prompt}\nEnhanced Response: ", end="")

        # Use a streamer to display in real-time while also capturing
        streamer = TextStreamer(tokenizer, skip_prompt=True)
        output = model.generate(
            **inputs,
            max_new_tokens=250,
            temperature=0.85,
            top_p=0.95,
            top_k=60,
            streamer=streamer,
        )

        # Store the prompt and response
        results.append({
            "prompt": prompt,
            "response": tokenizer.decode(output[0], skip_special_tokens=True)
        })

    return results

# Run the multi-prompt test
enhanced_responses = test_multiple_creative_responses()

# Example: Access the first response
print("\nFirst stored response:\n", enhanced_responses[0])



Test 1: I'm feeling burned out at work and it's affecting my relationships.
Enhanced Response: Okay, it sounds like you're dealing with a really tough situation. Burnout impacting your work and relationships is a serious issue, and it's good that you're recognizing it and seeking help.  Let's break this down.  I want to help you explore this further and brainstorm some potential solutions.  I'll try to give you a structured approach, but remember I'm an AI and can't provide professional advice.  This is for informational and supportive purposes only.

Here's a breakdown of how we can approach this, and what I can help you with:

**1. Understanding the Burnout:**

*   **What specifically is causing the burnout?**  (e.g., workload, lack of control, unfairness, lack of reward, not being valued, constant pressure, difficult relationships with colleagues/clients, lack of support).  The more specific you can be, the better.
*   **How long have you been feeling this way?**  (e.g., a few week