# SFT Training Pipeline
In addition to the Hugging Face Alignment Handbook, this pipeline uses the following tools:
- **Unsloth** for faster training with less memory
- **QLoRA** for parameter-efficient fine-tuning
- **Optuna** for hyperparameter optimization 
- **WandB** for experiment tracking and visualization

## Pipeline Structure
1. Setup & Configuration
2. Initialize WandB (experiment tracking)
3. Load Model & Tokenizer (Unsloth)
4. Prepare Dataset
5. Train Model
6. Save Model
7. Hyperparameter Search (optional, Optuna)
8. Quick Inference Test

In [None]:
# Cell 1: Setup and Imports
import logging
import os
import sys
import torch

# Configure logging: timestamped INFO-level records written to stdout.
# force=True (Python 3.8+) replaces any handlers already attached to the root
# logger. Without it, basicConfig() silently does nothing when the root logger
# is already configured -- e.g. after this cell has been edited and re-run, or
# when a previously imported library installed its own handler -- so the
# format and level below would never take effect.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)
# Logger used for messages emitted by this notebook itself.
logger = logging.getLogger(__name__)

# Import local modules
from src import (
    SFTScriptConfig,
    get_model_and_tokenizer,
    apply_peft,
    load_and_split_dataset,
    prepare_dataset,
    create_training_args,
    create_trainer,
    train,
    run_hpo,
    prepare_for_inference,
)

# Report the PyTorch build and whether a CUDA device is visible, one per line.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

## 1. Configuration
The config follows the alignment-handbook structure, with sections for `model`, `lora`, `data`, and `training`.

In [None]:
# Cell 2: Load Configuration from YAML
CONFIG_PATH = "recipes/SFT/config_pilot.yaml"
config = SFTScriptConfig.from_yaml(CONFIG_PATH)

# Summarise the loaded configuration, one print() call per config section.
# Every section header after the first starts with "\n" so that a blank line
# separates the sections in the output.
print(
    "=== Model Config ===",
    f"  Model: {config.model.model_name_or_path}",
    f"  Max seq length: {config.model.max_seq_length}",
    f"  Load in 4-bit: {config.model.load_in_4bit}",
    sep="\n",
)

print(
    "\n=== LoRA Config ===",
    f"  Rank (r): {config.lora.r}",
    f"  Alpha: {config.lora.lora_alpha}",
    f"  Dropout: {config.lora.lora_dropout}",
    sep="\n",
)

print(
    "\n=== Data Config ===",
    f"  Dataset: {config.data.dataset_id}",
    f"  Test split size: {config.data.test_split_size}",
    sep="\n",
)

print(
    "\n=== Training Config ===",
    f"  Output dir: {config.training.output_dir}",
    f"  Learning rate: {config.training.learning_rate}",
    f"  Batch size: {config.training.per_device_train_batch_size}",
    f"  Epochs: {config.training.num_train_epochs}",
    sep="\n",
)

## 2. Initialize WandB for Experiment Tracking

In [None]:
# Cell 3: Initialize WandB
import wandb

# Authenticate first; the run is opened only after login() returns.
wandb.login()

# Settings for the tracking run. The loaded configuration is attached as a
# plain dict so it is recorded alongside the run.
wandb_run_settings = dict(
    entity="alha8035-stockholm-university",
    project="pilot_model0_sft",
    config=config.to_dict(),
    tags=["sft", "qlora", "unsloth"],
)

# Initialize run
wandb.init(**wandb_run_settings)

In [None]:
## 3. Load Model & Tokenizer (with Unsloth)

In [None]:
# Cell 4: Load Model and Tokenizer with Unsloth
# All loader arguments are taken from the `model` section of the config.
model_load_kwargs = {
    "model_name": config.model.model_name_or_path,
    "max_seq_length": config.model.max_seq_length,
    "load_in_4bit": config.model.load_in_4bit,
}
model, tokenizer = get_model_and_tokenizer(**model_load_kwargs)

# Quick sanity report on what was loaded.
print(
    f"Model loaded: {config.model.model_name_or_path}",
    f"Model dtype: {model.dtype}",
    f"Tokenizer vocab size: {len(tokenizer)}",
    sep="\n",
)

In [None]:
# Cell 5: Apply PEFT/LoRA using Unsloth's optimized implementation
# Every adapter setting is read from the `lora` section of the config.
lora_cfg = config.lora
model = apply_peft(
    model,
    r=lora_cfg.r,
    lora_alpha=lora_cfg.lora_alpha,
    lora_dropout=lora_cfg.lora_dropout,
    target_modules=lora_cfg.target_modules,
    bias=lora_cfg.bias,
    use_gradient_checkpointing=lora_cfg.use_gradient_checkpointing,
    random_state=lora_cfg.random_state,
)

# Print trainable parameters: walk the parameters once, tallying the overall
# element count and the subset that still requires gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    param_size = param.numel()
    total_params += param_size
    if param.requires_grad:
        trainable_params += param_size
print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

## 4. Prepare Dataset

Load and preprocess the dataset following alignment-handbook's data pipeline.

In [None]:
# Cell 6: Load and Prepare Dataset
# Split settings all come from the `data` section of the config.
data_cfg = config.data
dataset = load_and_split_dataset(
    dataset_id=data_cfg.dataset_id,
    dataset_config=data_cfg.dataset_config,
    dataset_split=data_cfg.dataset_split,
    test_split_size=data_cfg.test_split_size,
    seed=data_cfg.seed,
)

# Prepare dataset (format to messages, apply chat template)
dataset = prepare_dataset(dataset, tokenizer, num_proc=data_cfg.num_proc)

# Split sizes; a missing test split is reported as 0 samples.
train_split = dataset['train']
test_split = dataset.get('test', [])
print(f"Train samples: {len(train_split)}")
print(f"Test samples: {len(test_split)}")

# Show at most the first 500 characters of the first training example.
first_text = dataset['train'][0]['text']
print(f"\nSample text preview (truncated):")
print(first_text[:500] + "...")

In [None]:
## 5. Train Model
#
# Create trainer and run training following alignment-handbook's training loop.
#
# NOTE: the heading and description above live inside a code cell, so they
# must be Python comments; as bare prose the description line is a
# SyntaxError and the cell cannot run.

# Cell 7: Create Training Arguments and Trainer
# Every value is read from the `training` section of the config, except
# max_seq_length, which comes from the `model` section.
training_args = create_training_args(
    output_dir=config.training.output_dir,
    learning_rate=config.training.learning_rate,
    per_device_train_batch_size=config.training.per_device_train_batch_size,
    gradient_accumulation_steps=config.training.gradient_accumulation_steps,
    num_train_epochs=config.training.num_train_epochs,
    max_seq_length=config.model.max_seq_length,
    eval_strategy=config.training.eval_strategy,
    eval_steps=config.training.eval_steps,
    save_steps=config.training.save_steps,
    logging_steps=config.training.logging_steps,
    warmup_ratio=config.training.warmup_ratio,
    weight_decay=config.training.weight_decay,
    lr_scheduler_type=config.training.lr_scheduler_type,
    optim=config.training.optim,
    bf16=config.training.bf16,
    gradient_checkpointing=config.training.gradient_checkpointing,
    save_total_limit=config.training.save_total_limit,
    seed=config.training.seed,
    report_to=config.training.report_to,
)

# dataset.get("test") yields None when there is no test split, so the trainer
# is then created without an evaluation dataset.
trainer = create_trainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset.get("test"),
    training_args=training_args,
)

print("Trainer created successfully!")

In [None]:
# Cell 8: Run Training
print("Starting training...")
train_result = trainer.train()

# Log final metrics: a banner followed by the final training loss.
print(
    f"\n=== Training Complete ===",
    f"Final train loss: {train_result.training_loss:.4f}",
    sep="\n",
)

# Run evaluation only when the dataset actually carries a test split.
has_test_split = dataset.get("test") is not None
if has_test_split:
    eval_metrics = trainer.evaluate()
    print(f"Eval loss: {eval_metrics['eval_loss']:.4f}")

## 6. Save Model

In [None]:
# Cell 9: Save Model and Tokenizer
OUTPUT_DIR = config.training.output_dir

# Write the model first, then the tokenizer, into the same output directory.
for save_artifact in (trainer.save_model, tokenizer.save_pretrained):
    save_artifact(OUTPUT_DIR)

print(f"Model saved to {OUTPUT_DIR}")

# Finish WandB run
wandb.finish()

## 7. (Optional) Hyperparameter Optimization with Optuna

Optuna provides smarter Bayesian optimization with trial pruning. Combined with WandB logging, you get:
- **Optuna**: Efficient search, early pruning of bad trials
- **WandB**: Visualization, comparison, collaboration

This is generally superior to WandB Sweeps alone for finding optimal hyperparameters.

In [None]:
# Cell 10: Run Optuna HPO (Optional - uncomment to run)
# This will search for optimal hyperparameters across multiple trials

# import optuna
# from optuna.visualization import plot_param_importances, plot_optimization_history

# # Reload config for HPO (starts fresh)
# hpo_config = SFTScriptConfig.from_yaml(CONFIG_PATH)

# # Run HPO study
# study = run_hpo(
#     config=hpo_config,
#     n_trials=20,  # Adjust based on your compute budget
#     study_name="pilot_sft_hpo",
# )

# # Display results
# print(f"\n=== Best Hyperparameters ===")
# for key, value in study.best_params.items():
#     print(f"  {key}: {value}")
# print(f"\nBest eval loss: {study.best_value:.4f}")

# # Visualize (requires plotly)
# # fig = plot_param_importances(study)
# # fig.show()
# # fig = plot_optimization_history(study)
# # fig.show()

## 8. Quick Inference Test

Test the trained model with a sample prompt.

In [None]:
# Cell 11: Quick Inference Test
from unsloth import FastLanguageModel

# Switch the model to Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)

# Test prompt
test_input = "I've been feeling really anxious lately about my job. I keep thinking I'm going to get fired even though there's no evidence of that."
system_prompt = "You are a helpful mental health counselling assistant. Please provide supportive and appropriate responses to the user's concerns."

# Conversation in chat-message form: system instruction, then the user turn.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": test_input},
]

# Apply chat template: render to a plain string (no tokenization yet) that
# ends with the generation prompt.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Tokenize the rendered prompt and move the tensors to the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate response: do_sample=False with a cap of 256 new tokens; gradient
# tracking is disabled for the duration of the call.
with torch.no_grad():
    generation_settings = dict(
        max_new_tokens=256,
        do_sample=False,
        pad_token_id=tokenizer.pad_token_id,
    )
    outputs = model.generate(**inputs, **generation_settings)

# Decode and display: drop the first prompt_length positions of the output
# sequence (the length of the tokenized prompt) and decode the remainder.
prompt_length = inputs.input_ids.shape[1]
generated_ids = outputs[0][prompt_length:]
response = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(
    "=== User Input ===",
    test_input,
    "\n=== Model Response ===",
    response,
    sep="\n",
)