# SFT Training Pipeline
In addition to the Hugging Face Alignment Handbook, the following tools were used:
- **Unsloth** for faster training with less memory
- **QLoRA** for parameter-efficient fine-tuning
- **Optuna** for hyperparameter optimization
- **WandB** for experiment tracking and visualization
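As a rough illustration of what the LoRA part of QLoRA adds, the sketch below applies a low-rank update to a frozen weight matrix in plain Python. The adapted weight is W + (alpha / r) · B·A, where only A and B are trained. The tiny matrices are made up for illustration. In QLoRA the frozen W is also stored in 4-bit, which this sketch does not model.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_weight(w, a, b, alpha, r):
    """Return W + (alpha / r) * (B @ A), the effective weight of a LoRA layer.

    w: frozen weight, shape (d_out, d_in)
    a: trainable down-projection, shape (r, d_in)
    b: trainable up-projection, shape (d_out, r)
    """
    scale = alpha / r
    delta = matmul(b, a)  # shape (d_out, d_in), rank at most r
    return [[wij + scale * dij for wij, dij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


# Toy example: d_out = d_in = 2, r = 1, alpha = 2 (scale = 2.0)
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]          # (r, d_in)
B_init = [[0.0], [0.0]]   # LoRA initialises B to zeros, so training starts from W
B_trained = [[1.0], [0.0]]

print(lora_weight(W, A, B_init, alpha=2, r=1))     # [[1.0, 0.0], [0.0, 1.0]]
print(lora_weight(W, A, B_trained, alpha=2, r=1))  # [[2.0, 1.0], [0.0, 1.0]]
```

With the pilot config below (r = 16, alpha = 32), the same scale factor alpha / r works out to 2.0.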

## Pipeline Structure
1. Setup & Configuration
2. Experiment Tracking (WandB)
3. Load Model & Tokenizer (Unsloth)
4. Prepare Dataset
5. Train Model
6. Save Model
7. Hyperparameter Search (Optuna)
8. Quick Inference Test

## Set Up & Import
Clone the GitHub repository and install its requirements so the notebook can run on Google Colab.

In [1]:
```python
!git clone https://github.com/Ally-Ha/pilot_act-cai_model0_SFT.git
%cd pilot_act-cai_model0_SFT
!pip install -r requirements.txt
```


```
fatal: destination path 'pilot_act-cai_model0_SFT' already exists and is not an empty directory.
/content/pilot_act-cai_model0_SFT
```


In [2]:
```python
import logging
import os
import sys
import torch

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

from src import (
    SFTScriptConfig,
    get_model_and_tokenizer,
    apply_peft,
    load_and_split_dataset,
    prepare_dataset,
    create_training_args,
    create_trainer,
    train,
    run_hpo,
    prepare_for_inference,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
```

```
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
PyTorch version: 2.10.0+cu128
CUDA available: True
```


## 1. Configuration
The config follows the alignment-handbook recipe structure, with sections for `model`, `lora`, `data`, and `training`.

In [3]:
```python
CONFIG_PATH = "recipes/SFT/config_pilot.yaml"
config = SFTScriptConfig.from_yaml(CONFIG_PATH)

# Display configuration
print("Model Config")
print(f"  Model: {config.model.model_name_or_path}")
print(f"  Max seq length: {config.model.max_seq_length}")
print(f"  Load in 4-bit: {config.model.load_in_4bit}")

print("\nLoRA Config")
print(f"  Rank (r): {config.lora.r}")
print(f"  Alpha: {config.lora.lora_alpha}")
print(f"  Dropout: {config.lora.lora_dropout}")

print("\nData Config")
print(f"  Dataset: {config.data.dataset_id}")
print(f"  Test split size: {config.data.test_split_size}")

print("\nTraining Config")
print(f"  Output dir: {config.training.output_dir}")
print(f"  Learning rate: {config.training.learning_rate}")
print(f"  Batch size: {config.training.per_device_train_batch_size}")
print(f"  Epochs: {config.training.num_train_epochs}")
```

```
Model Config
  Model: unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit
  Max seq length: 2048
  Load in 4-bit: True

LoRA Config
  Rank (r): 16
  Alpha: 32
  Dropout: 0.05

Data Config
  Dataset: ShenLab/MentalHealth16K
  Test split size: 1000

Training Config
  Output dir: data/llama-3.1-8b-instruct-sft-pilot
  Learning rate: 2e-05
  Batch size: 4
  Epochs: 1
```
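The nested attribute access used above (`config.model.max_seq_length`, `config.lora.r`, and so on) can be illustrated with plain dataclasses. This is a hypothetical sketch of how a sectioned config might be built from a dict (for example, the result of loading the YAML file). It is not the actual `SFTScriptConfig` implementation in `src`. Only the fields printed above are modelled, and their defaults are copied from that output.

```python
from dataclasses import dataclass, asdict


@dataclass
class ModelSection:
    model_name_or_path: str
    max_seq_length: int = 2048
    load_in_4bit: bool = True


@dataclass
class LoraSection:
    r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05


@dataclass
class DataSection:
    dataset_id: str = "ShenLab/MentalHealth16K"
    test_split_size: int = 1000


@dataclass
class TrainingSection:
    output_dir: str = "data/llama-3.1-8b-instruct-sft-pilot"
    learning_rate: float = 2e-5
    per_device_train_batch_size: int = 4
    num_train_epochs: int = 1


@dataclass
class ToyScriptConfig:
    """Hypothetical stand-in for SFTScriptConfig: one dataclass per YAML section."""
    model: ModelSection
    lora: LoraSection
    data: DataSection
    training: TrainingSection

    @classmethod
    def from_dict(cls, raw: dict) -> "ToyScriptConfig":
        # Missing sections fall back to their defaults
        return cls(
            model=ModelSection(**raw["model"]),
            lora=LoraSection(**raw.get("lora", {})),
            data=DataSection(**raw.get("data", {})),
            training=TrainingSection(**raw.get("training", {})),
        )

    def to_dict(self) -> dict:
        # asdict recurses into the sections, returning a nested plain dict
        # (the kind of object passed to wandb.init(config=...))
        return asdict(self)


raw = {"model": {"model_name_or_path": "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"}}
cfg = ToyScriptConfig.from_dict(raw)
print(cfg.model.max_seq_length, cfg.lora.r, cfg.lora.lora_alpha)  # 2048 16 32
```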


## 2. Initialize WandB for Experiment Tracking

In [4]:
```python
import wandb
wandb.login()

# Initialize run
wandb.init(
    entity="alha8035-stockholm-university",
    project="pilot_model0_sft",
    config=config.to_dict(),
    tags=["sft", "qlora", "unsloth"],
)
```

```
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from /root/.netrc.
wandb: Currently logged in as: alha8035 (alha8035-stockholm-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin

wandb: Detected [huggingface_hub.inference, openai] in use.
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/
```


## 3. Load Model & Tokenizer (with Unsloth)

In [5]:
```python
model, tokenizer = get_model_and_tokenizer(
    model_name=config.model.model_name_or_path,
    max_seq_length=config.model.max_seq_length,
    load_in_4bit=config.model.load_in_4bit,
)

print(f"Model loaded: {config.model.model_name_or_path}")
print(f"Model dtype: {model.dtype}")
print(f"Tokenizer vocab size: {len(tokenizer)}")
```

```
==((====))==  Unsloth 2026.1.4: Fast Llama patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.10.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.6.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.34. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Model loaded: unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit
Model dtype: torch.float16
Tokenizer vocab size: 128256
```
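The banner above reports a Tesla T4 with 14.741 GB of memory, compute capability 7.5 (`CUDA: 7.5`), and `Bfloat16 = FALSE`. The back-of-the-envelope sketch below shows why this setup relies on 4-bit loading and fp16 rather than bf16. It makes three assumptions. It estimates weight memory only, ignoring activations, optimizer state, LoRA weights and quantization constants. It takes Llama-3.1-8B to have roughly 8.03 billion parameters. It treats the banner's figure as GiB. It also uses the rule that native bf16 tensor-core support starts at compute capability 8.0 (Ampere).

```python
def weight_memory_gib(n_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in GiB (2**30 bytes)."""
    return n_params * bits_per_param / 8 / 2**30


def supports_bf16(cc_major: int, cc_minor: int) -> bool:
    """Native bf16 tensor-core support starts at compute capability 8.0 (Ampere)."""
    return (cc_major, cc_minor) >= (8, 0)


N_PARAMS = 8.03e9        # assumed approximate parameter count of Llama-3.1-8B
T4_MEMORY_GIB = 14.741   # "Max memory" from the banner above, treated as GiB

fp16 = weight_memory_gib(N_PARAMS, 16)
four_bit = weight_memory_gib(N_PARAMS, 4)
print(f"fp16 weights:  {fp16:.2f} GiB, fits on T4: {fp16 < T4_MEMORY_GIB}")
print(f"4-bit weights: {four_bit:.2f} GiB, fits on T4: {four_bit < T4_MEMORY_GIB}")
print(f"bf16 supported on T4 (compute capability 7.5): {supports_bf16(7, 5)}")
```

Even the fp16 weights alone would slightly exceed the reported memory before activations or optimizer state are counted, which is the practical case for the 4-bit checkpoint. The same check suggests that `bf16` in the training config should be false on this GPU, with fp16 mixed precision used instead.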


In [6]:
```python
# Apply PEFT/LoRA using Unsloth
model = apply_peft(
    model,
    r=config.lora.r,
    lora_alpha=config.lora.lora_alpha,
    lora_dropout=config.lora.lora_dropout,
    target_modules=config.lora.target_modules,
    bias=config.lora.bias,
    use_gradient_checkpointing=config.lora.use_gradient_checkpointing,
    random_state=config.lora.random_state,
)

# Print trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")
```

```
Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2026.1.4 patched 32 layers with 0 QKV layers, 0 O layers and 0 MLP layers.

Trainable parameters: 41,943,040 (0.92%)
```
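The 41,943,040 trainable parameters reported above can be reproduced by hand. The sketch below assumes `target_modules` covers all seven linear projections in each decoder layer: `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj` and `down_proj`. The actual list lives in the YAML config, which is not shown here. The sketch uses Llama-3.1-8B's shapes: hidden size 4096, MLP size 14336, 8 key/value heads of dimension 128, and 32 layers. Each adapted `in -> out` projection adds `r * (in + out)` parameters, split across matrices A and B.

```python
HIDDEN = 4096          # hidden size
INTERMEDIATE = 14336   # MLP intermediate size
KV_DIM = 8 * 128       # 8 key/value heads x head dim 128 (grouped-query attention)
N_LAYERS = 32
R = 16                 # LoRA rank from the config above

# (in_features, out_features) for each projection assumed to be in target_modules
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """A is (r x in) and B is (out x r), giving r * (in + out) trainable parameters."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in PROJECTIONS.values())
total = per_layer * N_LAYERS
print(f"per layer: {per_layer:,}, total: {total:,}")  # per layer: 1,310,720, total: 41,943,040
```

The percentage above is 0.92% rather than roughly 0.5% of 8B parameters, and packing is the likely reason. bitsandbytes stores 4-bit weights two per byte, so `numel()` on the quantized layers counts about half of their logical parameters. That shrinks the denominator of the ratio computed in the cell above.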


## 4. Prepare Dataset

Load and preprocess the dataset, following the alignment-handbook data pipeline.

In [8]:
```python
dataset = load_and_split_dataset(
    dataset_id=config.data.dataset_id,
    dataset_config=config.data.dataset_config,
    dataset_split=config.data.dataset_split,
    test_split_size=config.data.test_split_size,
    seed=config.data.seed,
)

# Prepare dataset (format to messages, apply chat template)
dataset = prepare_dataset(dataset, tokenizer, num_proc=config.data.num_proc)

print(f"Train samples: {len(dataset['train'])}")
print(f"Test samples: {len(dataset.get('test', []))}")
```

```
DatasetNotFoundError: Dataset 'ShenLab/MentalHealth16K' doesn't exist on the Hub or cannot be accessed.
```
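`prepare_dataset` formats rows into chat messages before the chat template is applied. Below is a minimal sketch of that conversion. It assumes each row has `instruction`, `input` and `output` fields, which is an assumption about the dataset schema. The `to_messages` helper is hypothetical and not taken from `src`, and the example row is made up.

```python
def to_messages(row: dict) -> dict:
    """Map an instruction/input/output row to the chat "messages" format
    (a list of role/content dicts) consumed by tokenizer.apply_chat_template."""
    messages = []
    if row.get("instruction"):
        # Treat the instruction as the system prompt when present
        messages.append({"role": "system", "content": row["instruction"]})
    messages.append({"role": "user", "content": row["input"]})
    messages.append({"role": "assistant", "content": row["output"]})
    return {"messages": messages}


example = {  # made-up row for illustration
    "instruction": "You are a helpful mental health counselling assistant.",
    "input": "I can't sleep before exams.",
    "output": "That sounds stressful. Let's look at what happens the night before.",
}
print([m["role"] for m in to_messages(example)["messages"]])  # ['system', 'user', 'assistant']
```

With 🤗 Datasets, a function like this would typically be applied with `dataset.map(to_messages, remove_columns=[...])`, so only the `messages` column remains for templating.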

## 5. Train Model

Create the trainer and run training, following the alignment-handbook training loop.

In [None]:
```python
training_args = create_training_args(
    output_dir=config.training.output_dir,
    learning_rate=config.training.learning_rate,
    per_device_train_batch_size=config.training.per_device_train_batch_size,
    gradient_accumulation_steps=config.training.gradient_accumulation_steps,
    num_train_epochs=config.training.num_train_epochs,
    max_seq_length=config.model.max_seq_length,
    eval_strategy=config.training.eval_strategy,
    eval_steps=config.training.eval_steps,
    save_steps=config.training.save_steps,
    logging_steps=config.training.logging_steps,
    warmup_ratio=config.training.warmup_ratio,
    weight_decay=config.training.weight_decay,
    lr_scheduler_type=config.training.lr_scheduler_type,
    optim=config.training.optim,
    bf16=config.training.bf16,
    gradient_checkpointing=config.training.gradient_checkpointing,
    save_total_limit=config.training.save_total_limit,
    seed=config.training.seed,
    report_to=config.training.report_to,
)

trainer = create_trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("test"),
    training_args=training_args,
)

print("Trainer created successfully!")
```
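The number of optimizer steps, and how many of them are warmup, follows from the batch settings passed to `create_training_args`. The sketch below uses the usual arithmetic: the effective batch is the per-device batch × gradient accumulation × number of GPUs, and warmup steps are ceil(warmup_ratio × total steps). Exact rounding of the last partial batch varies across Trainer versions, so treat the result as approximate. The values for `gradient_accumulation_steps`, `warmup_ratio` and the training-set size are hypothetical placeholders, because this notebook does not print them.

```python
import math


def training_steps(n_train: int, per_device_bs: int, grad_accum: int,
                   n_gpus: int = 1, epochs: int = 1) -> int:
    """Approximate optimizer steps: one step per effective batch, rounded up."""
    effective_bs = per_device_bs * grad_accum * n_gpus
    return math.ceil(n_train / effective_bs) * epochs


def warmup_steps(total_steps: int, warmup_ratio: float) -> int:
    """Warmup length derived from warmup_ratio, rounded up."""
    return math.ceil(total_steps * warmup_ratio)


# Only the per-device batch size (4), epochs (1) and GPU count (1) come from the outputs above
N_TRAIN = 15_000      # hypothetical training-set size
GRAD_ACCUM = 4        # hypothetical gradient_accumulation_steps
WARMUP_RATIO = 0.1    # hypothetical warmup_ratio

steps = training_steps(N_TRAIN, per_device_bs=4, grad_accum=GRAD_ACCUM, epochs=1)
print(steps, warmup_steps(steps, WARMUP_RATIO))  # 938 94
```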

In [None]:
```python
# Run training
print("Starting training...")
train_result = trainer.train()

# Log final metrics
print("\n=== Training Complete ===")
print(f"Final train loss: {train_result.training_loss:.4f}")

# Evaluate if a test set exists
if dataset.get("test") is not None:
    eval_metrics = trainer.evaluate()
    print(f"Eval loss: {eval_metrics['eval_loss']:.4f}")
```
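The printed losses are mean per-token cross-entropy in nats, so they can be converted to perplexity with `exp(loss)`, which is sometimes easier to compare across runs. A small sketch follows. The loss values in the loop are illustrative only, not results from this notebook.

```python
import math


def perplexity(mean_ce_loss: float) -> float:
    """Perplexity from a mean per-token cross-entropy loss measured in nats."""
    return math.exp(mean_ce_loss)


for loss in (2.0, 1.5, 1.0):  # illustrative loss values only
    print(f"loss {loss:.2f} -> perplexity {perplexity(loss):.2f}")
```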

## 6. Save Model

In [None]:
```python
# Save Model and Tokenizer
OUTPUT_DIR = config.training.output_dir

trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(f"Model saved to {OUTPUT_DIR}")

# Finish WandB run
wandb.finish()
```

## 7. Hyperparameter Optimization with Optuna

In [None]:
```python
import optuna
from optuna.visualization import plot_param_importances, plot_optimization_history
import plotly

# Reload config for HPO (starts fresh)
hpo_config = SFTScriptConfig.from_yaml(CONFIG_PATH)

# Run HPO study
study = run_hpo(
    config=hpo_config,
    n_trials=20,  # Adjust based on your compute budget
    study_name="pilot_sft_hpo",
)

# Display results
print("\n=== Best Hyperparameters ===")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")
print(f"\nBest eval loss: {study.best_value:.4f}")

# Visualize
fig = plot_param_importances(study)
fig.show()
fig = plot_optimization_history(study)
fig.show()
```
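`run_hpo` wraps Optuna, whose default single-objective sampler is TPE. The sketch below shows only the shape of a study: a search space, `n_trials` evaluations of an objective, and a best value with its parameters. It is a plain random-search stand-in in pure Python, not Optuna's algorithm. The search space and toy objective are hypothetical. The learning rate is sampled log-uniformly, in the spirit of Optuna's `suggest_float(..., log=True)`, because it is usually searched across orders of magnitude.

```python
import math
import random


def sample_params(rng: random.Random) -> dict:
    """Hypothetical search space: log-uniform learning rate, categorical LoRA rank."""
    log_lr = rng.uniform(math.log(1e-5), math.log(5e-4))
    return {"learning_rate": math.exp(log_lr), "r": rng.choice([8, 16, 32])}


def toy_objective(params: dict) -> float:
    """Stand-in for an eval loss, minimised near lr = 1e-4 and r = 16 (purely illustrative)."""
    return (math.log10(params["learning_rate"]) + 4) ** 2 + abs(params["r"] - 16) / 16


def random_search(n_trials: int, seed: int = 0):
    """Evaluate n_trials random configurations and keep the lowest objective value."""
    rng = random.Random(seed)
    best_params, best_value = None, float("inf")
    for _ in range(n_trials):
        params = sample_params(rng)
        value = toy_objective(params)
        if value < best_value:
            best_params, best_value = params, value
    return best_params, best_value


best_params, best_value = random_search(n_trials=20)
print(best_params, round(best_value, 4))
```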

## 8. Quick Inference Test

Test the trained model with a sample prompt.

In [None]:
```python
from unsloth import FastLanguageModel

# Enable Unsloth's faster inference mode
FastLanguageModel.for_inference(model)

# Test prompt
test_input = "I've been feeling really anxious lately about my job. I keep thinking I'm going to get fired even though there's no evidence of that."
system_prompt = "You are a helpful mental health counselling assistant, please answer the mental health questions based on the patient's description.  The assistant gives helpful, comprehensive, and appropriate answers to the user's questions."

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": test_input}
]

# Apply chat template (the rendered Llama 3.1 prompt already starts with <|begin_of_text|>)
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)

# Generate response; add_special_tokens=False avoids prepending a second BOS token
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens (skip the prompt prefix)
response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(f"\nUser Input: {test_input}")
print(f"\nModel Response: {response}")
```
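The generation step has two details that are easy to get wrong. First, for a decoder-only model, `generate` returns the prompt tokens as a prefix of its output, so the response must be sliced off after the prompt length. Second, a rendered Llama 3 chat prompt already begins with `<|begin_of_text|>`, so tokenizing it again with special tokens enabled can add a second BOS. The pure-Python sketch below illustrates both on plain lists of token IDs. It assumes `128000` is Llama 3's `<|begin_of_text|>` ID, and the other IDs are made up.

```python
BOS_ID = 128000  # assumed ID of Llama 3's <|begin_of_text|> token


def count_leading_bos(ids: list[int], bos_id: int = BOS_ID) -> int:
    """Count consecutive BOS tokens at the start of a sequence (more than 1 signals a duplicate)."""
    n = 0
    while n < len(ids) and ids[n] == bos_id:
        n += 1
    return n


def new_tokens(output_ids: list[int], prompt_len: int) -> list[int]:
    """generate() returns prompt + continuation for decoder-only models; keep only the continuation."""
    return output_ids[prompt_len:]


prompt_ids = [BOS_ID, 11, 12, 13]        # hypothetical tokenized prompt
output_ids = prompt_ids + [21, 22, 23]   # hypothetical generate() output
print(count_leading_bos([BOS_ID] + prompt_ids))  # 2 (duplicated BOS)
print(new_tokens(output_ids, len(prompt_ids)))   # [21, 22, 23]
```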