# Train the Models

Train four models, each on a different dataset, using an 80/20 train/test split.
- Using PEFT (QLoRA)

- Following the Hugging Face Alignment Handbook (`huggingface/alignment-handbook`).

- Base model: Llama-3.1-8B-Instruct (loaded via Unsloth for efficiency).

- Weights & Biases (WandB) for experiment tracking and visualization.

All configuration details are in `notebooks/src/configs.py` and `notebooks/src/model.py`.

# 0. Set-up & Configuration

In [None]:
import os
import sys
from pathlib import Path
from typing import Any

import datasets
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from openai import AzureOpenAI
from tqdm import tqdm


In [None]:
# Install/upgrade unsloth and pin huggingface_hub to a compatible version
!pip install -U "unsloth[colab] @ git+https://github.com/unslothai/unsloth.git"


Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-y2os547u/unsloth_501bce0e07234cf7b974bc1347bb121a
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-y2os547u/unsloth_501bce0e07234cf7b974bc1347bb121a
  Resolved https://github.com/unslothai/unsloth.git to commit 0eebf900ff704810ee62585ba1fa1394baabd0bd


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hcanceled
[31mERROR: Operation cancelled by user[0m[31m
[0m^C


In [None]:
"""
Configuration dataclasses following alignment-handbook pattern.
All hyperparameters follow MentalChat16K paper (Xu et al., 2025).
"""
from dataclasses import dataclass, field
from typing import Optional, List
from dotenv import load_dotenv

load_dotenv(".env")

# Experiment presets: experiment id -> dataset to train on (Hub id or local
# path) and the output directory used for that experiment's training run.
MODEL_PRESETS = dict(
    modelpilot=dict(
        dataset="ShenLab/MentalChat16K",
        output_dir="data/model_pilot-sft-llama-3.1-8b-instruct",
    ),
    model0=dict(
        dataset="ShenLab/MentalChat16K",
        output_dir="data/model_0-sft-llama-3.1-8b-instruct",
    ),
    model1=dict(
        dataset="data/ds_generic",
        output_dir="data/model_1-sft-llama-3.1-8b-instruct",
    ),
    model2=dict(
        dataset="data/ds_constitution",
        output_dir="data/model_2-sft-llama-3.1-8b-instruct",
    ),
    model3=dict(
        dataset="data/ds_constitution_revised",
        output_dir="data/model_3-sft-llama-3.1-8b-instruct",
    ),
)


# Configuration Dataclasses
@dataclass
class ModelConfig:
    """Base-model loading options (checkpoint, precision, sequence length, quantisation)."""

    # Hub id or local path of the checkpoint; the default is Unsloth's
    # pre-quantised 4-bit Llama-3.1-8B-Instruct.
    model_name_or_path: str = field(default="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit")
    # Revision (branch / tag / commit) of the checkpoint.
    model_revision: str = field(default="main")
    # Requested compute dtype, stored by name.
    # NOTE(review): the Tesla T4 used in this notebook reports no bf16 and no
    # FlashAttention-2 support, so these two defaults do not fit that GPU.
    torch_dtype: str = field(default="bfloat16")
    attn_implementation: str = field(default="flash_attention_2")
    # Whether to execute custom modelling code shipped with the checkpoint.
    trust_remote_code: bool = field(default=False)
    # Maximum number of tokens per example.
    max_seq_length: int = field(default=2048)
    # Load the base weights in 4-bit (the "Q" in QLoRA).
    load_in_4bit: bool = field(default=True)


@dataclass
class LoraConfig:
    """QLoRA adapter settings, defaulting to the MentalChat16K paper's values."""

    # Adapter rank and scaling numerator.
    r: int = field(default=64)
    lora_alpha: int = field(default=16)
    # Dropout applied inside the LoRA branch.
    lora_dropout: float = field(default=0.1)
    # Every attention and MLP projection of the Llama block gets an adapter;
    # the factory returns a fresh list per instance.
    target_modules: List[str] = field(
        default_factory=lambda: "q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split()
    )
    # Bias handling mode passed through to the PEFT setup.
    bias: str = field(default="none")
    # Gradient-checkpointing mode name passed to Unsloth.
    use_gradient_checkpointing: str = field(default="unsloth")
    # Seed for adapter initialisation.
    random_state: int = field(default=42)


@dataclass
class DataConfig:
    """Where the training data comes from and how it is split and processed."""

    # Hub dataset id or local path, optional config name, and the split to load.
    dataset_id: str = field(default="ShenLab/MentalChat16K")
    dataset_config: Optional[str] = field(default=None)
    dataset_split: str = field(default="train")
    # Fraction held out as the test split (0.2 -> 80/20).
    test_split_size: float = field(default=0.2)
    # Seed for the train/test split.
    seed: int = field(default=42)
    # Worker processes used by dataset.map.
    num_proc: int = field(default=4)


@dataclass
class GenerationConfig:
    """Settings for generating responses through the Azure OpenAI API."""

    # --- Model / reasoning ---
    model: str = field(default="o3-mini")
    reasoning_level: str = field(default="medium")       # "low" | "medium" | "high"
    reasoning_summary: str = field(default="concise")    # "auto" | "concise" | "detailed" | "none"
    # Budget covers reasoning tokens as well as the visible answer; raise it if
    # responses come back truncated.
    max_completion_tokens: int = field(default=1500)

    # --- HuggingFace Hub destinations ---
    hf_username: str = field(default="AIforAlly")
    repo_generic: str = field(default="mentalchat16k-generic-responses")
    repo_constitution: str = field(default="mentalchat16k-constitution-responses")

    # --- Local output files ---
    output_dir: str = field(default="data/responses/working_files/")
    generic_csv: str = field(default="response_generic.csv")
    constitution_csv: str = field(default="response_constitution.csv")


@dataclass
class TrainingConfig:
    """Trainer hyperparameters, defaulting to the MentalChat16K paper's settings."""

    # Directory for checkpoints and trainer outputs.
    output_dir: str = field(default="data/sft-output")

    # --- Optimizer and schedule (MentalChat16K paper) ---
    learning_rate: float = field(default=2.0e-4)
    per_device_train_batch_size: int = field(default=8)
    per_device_eval_batch_size: int = field(default=8)
    # 8 per device x 8 accumulation steps = effective batch of 64 on one GPU.
    gradient_accumulation_steps: int = field(default=8)
    num_train_epochs: int = field(default=5)
    # -1 means no step cap; length is governed by num_train_epochs.
    max_steps: int = field(default=-1)
    warmup_ratio: float = field(default=0.03)
    weight_decay: float = field(default=0.01)
    max_grad_norm: float = field(default=0.3)
    lr_scheduler_type: str = field(default="cosine")
    optim: str = field(default="paged_adamw_32bit")

    # --- Evaluation, checkpointing and logging ---
    eval_strategy: str = field(default="no")  # no mid-training eval, to mirror the paper setup
    save_strategy: str = field(default="steps")
    save_steps: int = field(default=100)
    save_total_limit: int = field(default=2)
    logging_steps: int = field(default=10)

    # --- Precision and memory ---
    bf16: bool = field(default=True)
    fp16: bool = field(default=False)
    gradient_checkpointing: bool = field(default=True)

    # --- HuggingFace Hub ---
    push_to_hub: bool = field(default=False)
    hub_model_id: Optional[str] = field(default=None)

    # --- Reporting and reproducibility ---
    report_to: List[str] = field(default_factory=lambda: ["wandb"])
    seed: int = field(default=42)


# Combined Configuration

@dataclass
class SFTScriptConfig:
    """Combined configuration for one SFT experiment.

    Bundles the model, LoRA, data, generation and training sections. Each
    section is built by its own default_factory, so two SFTScriptConfig
    instances never share a sub-config object.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    lora: LoraConfig = field(default_factory=LoraConfig)
    data: DataConfig = field(default_factory=DataConfig)
    generation: GenerationConfig = field(default_factory=GenerationConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    
    @classmethod
    def for_model(cls, model_id: str) -> "SFTScriptConfig":
        """
        Build the default config with one experiment preset applied.
        
        Args:
            model_id: A key of MODEL_PRESETS (currently 'modelpilot', 'model0',
                'model1', 'model2' or 'model3').
        
        Returns:
            A new SFTScriptConfig whose ``data.dataset_id`` and
            ``training.output_dir`` come from the preset; every other field
            keeps its default value.
        
        Raises:
            ValueError: If ``model_id`` is not a key of MODEL_PRESETS.
        
        Example:
            config = SFTScriptConfig.for_model("model0")
        """
        if model_id not in MODEL_PRESETS:
            raise ValueError(f"Unknown model_id: {model_id}. Choose from {list(MODEL_PRESETS.keys())}")
        
        preset = MODEL_PRESETS[model_id]
        config = cls()
        config.data.dataset_id = preset["dataset"]
        config.training.output_dir = preset["output_dir"]
        return config
    
    def to_dict(self) -> dict:
        """Return the whole configuration as nested plain dicts (e.g. for W&B run config)."""
        from dataclasses import asdict
        return asdict(self)

In [None]:
# model.py
"""
Model loading with Unsloth for efficient QLoRA training.
Following alignment-handbook pattern but using Unsloth backend.
"""
import logging
import torch
from typing import Tuple, Optional
from unsloth import FastLanguageModel

logger = logging.getLogger(__name__)


def get_model_and_tokenizer(
    model_name: str,
    max_seq_length: int = 2048,
    dtype: Optional[torch.dtype] = None,
    load_in_4bit: bool = True,
) -> Tuple[FastLanguageModel, Any]:
    """
    Load a model and its tokenizer through Unsloth's FastLanguageModel.
    
    Args:
        model_name: HuggingFace model ID or Unsloth optimized model
        max_seq_length: Maximum sequence length
        dtype: Model dtype (None lets Unsloth choose)
        load_in_4bit: Whether to use 4-bit quantization
    
    Returns:
        Tuple of (model, tokenizer) as returned by
        ``FastLanguageModel.from_pretrained``.
    """
    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info("Loading model: %s", model_name)
    
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
    
    logger.info("Model loaded successfully")
    return model, tokenizer


def apply_peft(
    model: FastLanguageModel,
    r: int = 64,  # MentalChat16K paper
    lora_alpha: int = 16,  # MentalChat16K paper
    lora_dropout: float = 0.1,  # MentalChat16K paper
    target_modules: Optional[list] = None,
    bias: str = "none",
    use_gradient_checkpointing: str = "unsloth",
    random_state: int = 42,
) -> FastLanguageModel:
    """
    Attach LoRA adapters to ``model`` via ``FastLanguageModel.get_peft_model``.

    Args:
        model: Base model from get_model_and_tokenizer.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        lora_dropout: Dropout probability inside the LoRA branch.
        target_modules: Module names to adapt; None selects every attention
            and MLP projection (q/k/v/o, gate/up/down).
        bias: Bias setting ("none", "all", "lora_only").
        use_gradient_checkpointing: Gradient checkpointing mode.
        random_state: Seed for adapter initialisation.

    Returns:
        The model returned by ``FastLanguageModel.get_peft_model``.
    """
    modules = target_modules if target_modules is not None else [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]

    logger.info(f"Applying LoRA with r={r}, alpha={lora_alpha}")

    lora_settings = dict(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=modules,
        bias=bias,
        use_gradient_checkpointing=use_gradient_checkpointing,
        random_state=random_state,
    )
    return FastLanguageModel.get_peft_model(model, **lora_settings)


def prepare_for_inference(model: FastLanguageModel) -> FastLanguageModel:
    """
    Switch a trained model into Unsloth's inference mode.

    Calls ``FastLanguageModel.for_inference`` on the model (described by
    Unsloth as giving ~2x faster generation), ignores that call's return value,
    and hands back the same model object so calls can be chained.
    """
    enable_fast_inference = FastLanguageModel.for_inference
    enable_fast_inference(model)
    return model



🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


Exception: huggingface-hub>=0.34.0,<1.0 is required for a normal functioning of this module, but found huggingface-hub==0.26.5.
Try: `pip install transformers -U` or `pip install -e '.[dev]'` if you're working with git main

In [None]:
"""
Data loading and preprocessing following alignment-handbook pattern.
"""
import logging
from datasets import load_dataset, DatasetDict
from typing import Optional

logger = logging.getLogger(__name__)


def load_and_split_dataset(
    dataset_id: str,
    dataset_config: Optional[str] = None,
    dataset_split: str = "train",
    test_split_size: float = 0.2,
    seed: int = 42,
    **kwargs
) -> DatasetDict:
    """
    Load one split of a dataset and optionally carve out a seeded test split.

    Args:
        dataset_id: HuggingFace dataset ID (or path accepted by load_dataset).
        dataset_config: Dataset configuration name.
        dataset_split: Split to load.
        test_split_size: Fraction for the test split; any value that is not
            greater than 0 keeps everything for training.
        seed: Random seed for a reproducible split.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        DatasetDict with 'train' (and 'test' when a split was made).
    """
    # NOTE(review): presets also pass local paths such as "data/ds_generic".
    # load_dataset handles directories of data files, but output written with
    # Dataset.save_to_disk needs load_from_disk — confirm how those were saved.
    logger.info(f"Loading dataset: {dataset_id}")
    full = load_dataset(dataset_id, dataset_config, split=dataset_split)

    # Guard clause: no positive test fraction -> keep the whole split for training.
    if not test_split_size > 0:
        return DatasetDict({'train': full})

    logger.info(f"Splitting dataset with test_size={test_split_size}")
    halves = full.train_test_split(test_size=test_split_size, seed=seed)
    return DatasetDict({name: halves[name] for name in ('train', 'test')})


def format_to_messages(example: dict) -> dict:
    """
    Convert an instruction-style row into a three-turn chat conversation.

    Expected input columns: instruction (-> system), input (-> user),
    output (-> assistant). A missing column or a null value (None) becomes an
    empty string, so the chat template never receives ``None`` as content.

    Args:
        example: One dataset row.

    Returns:
        Dict with a single 'messages' key holding role/content dicts.
    """
    def _text(key: str):
        # .get(key, '') alone would still pass through an explicit None cell.
        value = example.get(key)
        return '' if value is None else value

    messages = [
        {"role": "system", "content": _text('instruction')},
        {"role": "user", "content": _text('input')},
        {"role": "assistant", "content": _text('output')},
    ]
    return {"messages": messages}


def apply_chat_template(example: dict, tokenizer) -> dict:
    """
    Render a conversation to one training string using the tokenizer's chat template.

    The template is applied untokenized and without a trailing generation
    prompt, because the conversation already ends with the assistant turn.

    Args:
        example: Row holding a 'messages' list.
        tokenizer: Tokenizer exposing ``apply_chat_template``.

    Returns:
        Dict with a 'text' key holding the rendered conversation.
    """
    conversation = example["messages"]
    rendered = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=False,
        tokenize=False,
    )
    return dict(text=rendered)


def prepare_dataset(dataset: DatasetDict, tokenizer, num_proc: int = 4) -> DatasetDict:
    """
    Prepare a dataset for SFT training.

    Two passes: rows become chat 'messages' (format_to_messages), then the
    messages are rendered to one 'text' string with the tokenizer's chat
    template. The second pass also drops every other column (the raw
    instruction/input/output columns and the intermediate 'messages'), so the
    result carries only 'text'.

    Args:
        dataset: Raw dataset with instruction/input/output columns.
        tokenizer: Tokenizer providing the chat template.
        num_proc: Number of processes for mapping.

    Returns:
        Dataset(Dict) whose only column is 'text', for SFTTrainer with
        dataset_text_field="text".
    """
    logger.info("Formatting dataset to messages...")
    dataset = dataset.map(format_to_messages, num_proc=num_proc)

    # DatasetDict.column_names is a {split: [columns]} dict; a plain Dataset
    # returns a list. Assumes all splits share columns (true for splits made by
    # train_test_split) — TODO confirm for other callers.
    column_names = dataset.column_names
    if isinstance(column_names, dict):
        column_names = next(iter(column_names.values()), [])

    logger.info("Applying chat template...")
    # NOTE(review): recent TRL SFTTrainer versions treat a 'messages' column as
    # conversational data and may tokenise it instead of dataset_text_field;
    # removing it leaves the pre-rendered 'text' column as the only input.
    dataset = dataset.map(
        lambda x: apply_chat_template(x, tokenizer),
        num_proc=num_proc,
        remove_columns=column_names,
    )

    return dataset

In [None]:
"""
SFT training helpers: build SFTConfig/SFTTrainer and run one QLoRA fine-tuning job.
Following alignment-handbook pattern with Unsloth efficiency.
"""
import logging
import os
from typing import Optional, Dict, Any

from trl import SFTTrainer, SFTConfig
from transformers import set_seed

logger = logging.getLogger(__name__)


def create_training_args(
    training_cfg,  # TrainingConfig
    max_seq_length: int,
) -> SFTConfig:
    """
    Create SFT training arguments from a TrainingConfig.

    Every TrainingConfig field with an SFTConfig counterpart is forwarded,
    including ``max_steps`` (a positive value overrides ``num_train_epochs``)
    and the Hub settings ``push_to_hub`` / ``hub_model_id``.

    Args:
        training_cfg: TrainingConfig dataclass instance.
        max_seq_length: Maximum tokens per example (from ModelConfig).

    Returns:
        SFTConfig for TRL SFTTrainer.
    """
    import dataclasses

    # Newer TRL releases renamed SFTConfig.max_seq_length to max_length and
    # later dropped the old name; use whichever field the installed (possibly
    # Unsloth-patched) SFTConfig defines.
    sft_field_names = {f.name for f in dataclasses.fields(SFTConfig)}
    seq_len_kwarg = "max_seq_length" if "max_seq_length" in sft_field_names else "max_length"

    return SFTConfig(
        output_dir=training_cfg.output_dir,
        learning_rate=training_cfg.learning_rate,
        per_device_train_batch_size=training_cfg.per_device_train_batch_size,
        per_device_eval_batch_size=training_cfg.per_device_eval_batch_size,
        gradient_accumulation_steps=training_cfg.gradient_accumulation_steps,
        num_train_epochs=training_cfg.num_train_epochs,
        max_steps=training_cfg.max_steps,
        eval_strategy=training_cfg.eval_strategy,
        save_strategy=training_cfg.save_strategy,
        save_steps=training_cfg.save_steps,
        logging_steps=training_cfg.logging_steps,
        warmup_ratio=training_cfg.warmup_ratio,
        weight_decay=training_cfg.weight_decay,
        max_grad_norm=training_cfg.max_grad_norm,
        lr_scheduler_type=training_cfg.lr_scheduler_type,
        optim=training_cfg.optim,
        bf16=training_cfg.bf16,
        fp16=training_cfg.fp16,
        gradient_checkpointing=training_cfg.gradient_checkpointing,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        save_total_limit=training_cfg.save_total_limit,
        seed=training_cfg.seed,
        dataset_text_field="text",
        packing=False,
        push_to_hub=training_cfg.push_to_hub,
        hub_model_id=training_cfg.hub_model_id,
        report_to=training_cfg.report_to,
        **{seq_len_kwarg: max_seq_length},
    )


def create_trainer(
    model,
    tokenizer,
    train_dataset,
    eval_dataset,
    training_args: SFTConfig,
) -> SFTTrainer:
    """
    Wire model, tokenizer, datasets and SFTConfig into a TRL SFTTrainer.

    The tokenizer is passed as ``processing_class``; ``eval_dataset`` may be None.
    """
    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "processing_class": tokenizer,
    }
    return SFTTrainer(**trainer_kwargs)


def train(
    config: SFTScriptConfig
) -> Dict[str, Any]:
    """
    Run one QLoRA SFT experiment end to end.

    Seeds RNGs, loads the 4-bit base model, attaches LoRA adapters, loads,
    splits and formats the dataset, builds SFTConfig + SFTTrainer, trains, and,
    when a test split exists, evaluates once at the end.

    Args:
        config: Full SFT configuration (e.g. from SFTScriptConfig.for_model).
            It is not mutated.

    Returns:
        Dict with keys 'trainer', 'model', 'tokenizer' and 'metrics'. Metrics
        always holds 'train_loss', plus 'eval_loss' when a test split exists.
    """
    import dataclasses

    import torch

    set_seed(config.training.seed)

    training_cfg = config.training
    # TrainingArguments raises an error for bf16=True on GPUs without native
    # bf16 (compute capability < 8.0, e.g. the Tesla T4). Fall back to fp16 on a
    # copy so the caller's config stays untouched.
    if (
        training_cfg.bf16
        and torch.cuda.is_available()
        and torch.cuda.get_device_capability()[0] < 8
    ):
        logger.warning("bf16 is not supported on this GPU; falling back to fp16.")
        training_cfg = dataclasses.replace(training_cfg, bf16=False, fp16=True)

    # Load model and tokenizer
    model, tokenizer = get_model_and_tokenizer(
        model_name=config.model.model_name_or_path,
        max_seq_length=config.model.max_seq_length,
        load_in_4bit=config.model.load_in_4bit,
    )

    # Apply PEFT
    model = apply_peft(
        model,
        r=config.lora.r,
        lora_alpha=config.lora.lora_alpha,
        lora_dropout=config.lora.lora_dropout,
        target_modules=config.lora.target_modules,
        bias=config.lora.bias,
        use_gradient_checkpointing=config.lora.use_gradient_checkpointing,
        random_state=config.lora.random_state,
    )

    # Load and prepare dataset
    dataset = load_and_split_dataset(
        dataset_id=config.data.dataset_id,
        dataset_config=config.data.dataset_config,
        dataset_split=config.data.dataset_split,
        test_split_size=config.data.test_split_size,
        seed=config.data.seed,
    )
    dataset = prepare_dataset(dataset, tokenizer, num_proc=config.data.num_proc)
    eval_dataset = dataset.get("test")

    # Create training arguments
    training_args = create_training_args(
        training_cfg=training_cfg,
        max_seq_length=config.model.max_seq_length,
    )

    # Create trainer
    trainer = create_trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"],
        eval_dataset=eval_dataset,
        training_args=training_args,
    )

    # Train
    logger.info("*** Starting training ***")
    train_result = trainer.train()

    # Evaluate once on the held-out split, if there is one
    metrics = {}
    if eval_dataset is not None:
        eval_metrics = trainer.evaluate()
        metrics["eval_loss"] = eval_metrics["eval_loss"]

    metrics["train_loss"] = train_result.training_loss

    return {
        "trainer": trainer,
        "model": model,
        "tokenizer": tokenizer,
        "metrics": metrics,
    }

In [None]:
import logging
import os
import sys
import torch

# Configure logging. force=True replaces handlers already attached to the root
# logger (a notebook kernel may have installed one); without it basicConfig()
# silently does nothing and the helpers' INFO messages never appear.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
logger = logging.getLogger(__name__)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.10.0+cu128
CUDA available: True


In [None]:
config = SFTScriptConfig.for_model("modelpilot")

# Display configuration: (section header, [(label, value), ...]) pairs,
# printed as "  label: value" lines under each header.
_sections = [
    ("Model Config", [
        ("Model", config.model.model_name_or_path),
        ("Max seq length", config.model.max_seq_length),
        ("Load in 4-bit", config.model.load_in_4bit),
    ]),
    ("\nLoRA Config", [
        ("Rank (r)", config.lora.r),
        ("Alpha", config.lora.lora_alpha),
        ("Dropout", config.lora.lora_dropout),
    ]),
    ("\nData Config", [
        ("Dataset", config.data.dataset_id),
        ("Test split size", config.data.test_split_size),
    ]),
    ("\nTraining Config", [
        ("Output dir", config.training.output_dir),
        ("Learning rate", config.training.learning_rate),
        ("Batch size", config.training.per_device_train_batch_size),
        ("Epochs", config.training.num_train_epochs),
    ]),
]
for _header, _rows in _sections:
    print(_header)
    for _label, _value in _rows:
        print(f"  {_label}: {_value}")

Model Config
  Model: unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit
  Max seq length: 2048
  Load in 4-bit: True

LoRA Config
  Rank (r): 64
  Alpha: 16
  Dropout: 0.1

Data Config
  Dataset: ShenLab/MentalChat16K
  Test split size: 0.2

Training Config
  Output dir: data/model_pilot-sft-llama-3.1-8b-instruct
  Learning rate: 0.0002
  Batch size: 8
  Epochs: 5


In [None]:
import wandb

# Authenticate (prompts for an API key if none is configured), then open the
# run that training metrics are reported to (report_to includes "wandb").
wandb.login()

_wandb_run_kwargs = {
    "entity": "alha8035-stockholm-university",
    "project": "pilot_model0_sft",
    "config": config.to_dict(),
    "tags": ["sft", "qlora", "unsloth"],
}
wandb.init(**_wandb_run_kwargs)

wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice:wandb: You chose 'Use an existing W&B account'
wandb: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: Create a new API key at: https://wandb.ai/authorize?ref=models
wandb: Store your API key securely and do not share it.
wandb: Paste your API key and hit enter:wandb: ERROR Invalid API key: API key must have 40+ characters, has 1.
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice:wandb: You chose 'Use an existing W&B account'
wandb: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: Create a new API key at: https://wandb.ai/authorize?ref=models
wandb: Store your API key securely and do not share it.
wandb: Paste your API key and hit enter:wan

wandb: Detected [huggingface_hub.inference, openai] in use.
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/


# 1. Training Setup

Uses the hyperparameters from the MentalChat16K paper (Xu et al., 2025).

In [None]:
_model_id = config.model.model_name_or_path
model, tokenizer = get_model_and_tokenizer(
    model_name=_model_id,
    load_in_4bit=config.model.load_in_4bit,
    max_seq_length=config.model.max_seq_length,
)

# Quick sanity check of what was loaded
print(f"Model loaded: {_model_id}")
print(f"Model dtype: {model.dtype}")
print(f"Tokenizer vocab size: {len(tokenizer)}")

==((====))==  Unsloth 2026.2.1: Fast Llama patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.563 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.35. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


AttributeError: module 'huggingface_hub.constants' has no attribute 'HF_HUB_ENABLE_HF_TRANSFER'

In [None]:
# Apply PEFT/LoRA using Unsloth
model = apply_peft(
    model,
    r=config.lora.r,
    lora_alpha=config.lora.lora_alpha,
    lora_dropout=config.lora.lora_dropout,
    target_modules=config.lora.target_modules,
    bias=config.lora.bias,
    use_gradient_checkpointing=config.lora.use_gradient_checkpointing,
    random_state=config.lora.random_state,
)

# Print trainable parameters.
# Summing p.numel() undercounts 4-bit (bitsandbytes Params4bit) weights,
# because their storage packs two values per byte, and that inflates the
# trainable share. PEFT's counter corrects for the packing.
trainable_params, total_params = model.get_nb_trainable_parameters()
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

# 2. Training

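Train four models, each on a different dataset:

- Model 0: MentalChat16K dataset
- Model 1: ds_generic
- Model 2: ds_act
- Model 3: ds_constitution

Note: `MODEL_PRESETS` in the configuration cell currently maps `model2` to `data/ds_constitution` and `model3` to `data/ds_constitution_revised`; keep this list and the presets in sync.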

## 2.0 Pilot test

In [None]:
dataset = load_and_split_dataset(
    dataset_id=config.data.dataset_id,
    dataset_config=config.data.dataset_config,
    dataset_split=config.data.dataset_split,
    test_split_size=config.data.test_split_size,
    seed=config.data.seed,
)

# Pilot run: keep only the first rows of each split so a full pass is quick.
TRAIN_SUBSET = 800
TEST_SUBSET = 200

def _head(split, limit):
    """Return the first `limit` rows of `split` (or all of them if fewer)."""
    return split.select(range(min(limit, len(split))))

dataset["train"] = _head(dataset["train"], TRAIN_SUBSET)
if "test" in dataset:
    dataset["test"] = _head(dataset["test"], TEST_SUBSET)

# Render each row to the chat-formatted 'text' field used for SFT
dataset = prepare_dataset(dataset, tokenizer, num_proc=config.data.num_proc)

_n_train = len(dataset['train'])
_n_test = len(dataset.get('test', []))
print(f"Train samples: {_n_train}")
print(f"Test samples: {_n_test}")


In [None]:
# Choose mixed precision from the hardware. Native bf16 needs compute capability
# >= 8.0; the Colab T4 (7.5) therefore trains in fp16, while bf16-capable
# runtimes keep bf16 instead of being forced down to fp16.
_bf16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
config.training.bf16 = _bf16_supported
config.training.fp16 = not _bf16_supported

# create_training_args takes the full TrainingConfig object
training_args = create_training_args(
    training_cfg=config.training,
    max_seq_length=config.model.max_seq_length,
)

trainer = create_trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("test"),
    training_args=training_args,
)

print("Trainer created successfully!")

NameError: name 'create_training_args' is not defined

In [None]:
# Run training
print("Starting training...")
train_result = trainer.train()

# Report the final training loss
_final_train_loss = train_result.training_loss
print("\nTraining Complete")
print(f"Final train loss: {_final_train_loss:.4f}")

# Evaluate on the held-out split, when one exists
_test_split = dataset.get("test")
if _test_split is not None:
    eval_metrics = trainer.evaluate()
    print(f"Eval loss: {eval_metrics['eval_loss']:.4f}")

## 2.1 Model 0: MentalChat16K Dataset

## 2.2 Model 1: ds_generic

## 2.3 Model 2: ds_act

## 2.4 Model 3: ds_constitution

# 3. Save the Models & Run a Quick Inference Check

In [None]:
# Save the final model (adapter) and tokenizer locally
OUTPUT_DIR = config.training.output_dir
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved locally to: {OUTPUT_DIR}")

# Name the artifact after the output directory so this cell works for every
# preset (for "modelpilot" this is "model_pilot-sft-llama-3.1-8b-instruct").
ARTIFACT_NAME = Path(OUTPUT_DIR).name

artifact = wandb.Artifact(
    name=ARTIFACT_NAME,
    type="model",
    description=f"SFT model fine-tuned on {config.data.dataset_id} with QLoRA",
    metadata=config.to_dict(),
)
# Upload only the top-level files written by save_model/save_pretrained. The
# checkpoint-* subdirectories created during training (save_strategy="steps")
# typically also hold optimizer/scheduler state and would bloat the artifact.
for _path in sorted(Path(OUTPUT_DIR).iterdir()):
    if _path.is_file():
        artifact.add_file(str(_path))
wandb.log_artifact(artifact)
print(f"Model uploaded to WandB as artifact: {ARTIFACT_NAME}")

# Finish WandB run
wandb.finish()


In [None]:
from unsloth import FastLanguageModel

FastLanguageModel.for_inference(model)  # Unsloth's fast-generation mode

# Test prompt
test_input = "I've been feeling really anxious lately about my job. I keep thinking I'm going to get fired even though there's no evidence of that."
system_prompt = "You are a helpful mental health counselling assistant, please answer the mental health questions based on the patient's description.  The assistant gives helpful, comprehensive, and appropriate answers to the user's questions."

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": test_input}
]

# Render the conversation and open an assistant turn for the model to complete
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# The Llama 3.x chat template already emits <|begin_of_text|> at the start of
# the prompt, so stop the tokenizer from prepending a second BOS token.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens (skip the echoed prompt)
response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(f"\nUser Input: {test_input}")
print(f"\nModel Response: {response}")