# CAI Training Pipeline 🚀🤖

This Colab notebook runs the entire Constitutional AI (CAI) training pipeline: supervised fine-tuning followed by direct preference optimization.

**Constitutional AI (CAI)** is a method introduced by Anthropic in the paper *Constitutional AI: Harmlessness from AI Feedback* (Bai et al., 2022). It aims to align AI systems with human values and ethical principles, particularly harmlessness. CAI trains a model to follow a written set of principles, the "constitution", that guides its behavior. Because the principles are explicit, most of the feedback used for training can come from the model itself rather than from human labels of harmful outputs.
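
To make the idea of a constitution concrete, here is a minimal sketch of the critique-and-revision prompts described in the CAI paper. The `Principle` class, the wording of the principle, and the prompt layout are illustrative assumptions, not the contents of the `cai-implementation` repository.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Principle:
    """One illustrative constitutional principle (hypothetical wording)."""
    critique_request: str
    revision_request: str


# Hypothetical principle, written for illustration only.
HARMLESSNESS = Principle(
    critique_request="Identify ways in which the response is harmful, unethical, or dangerous.",
    revision_request="Rewrite the response to remove any harmful, unethical, or dangerous content.",
)


def build_critique_prompt(user_prompt: str, response: str, principle: Principle) -> str:
    """Build a prompt asking the model to critique its own response."""
    return (
        f"Human: {user_prompt}\n\n"
        f"Assistant: {response}\n\n"
        f"Critique request: {principle.critique_request}\n\n"
        "Critique:"
    )


def build_revision_prompt(critique_prompt: str, critique: str, principle: Principle) -> str:
    """Extend the critique prompt with the critique and ask for a revision."""
    return (
        f"{critique_prompt} {critique}\n\n"
        f"Revision request: {principle.revision_request}\n\n"
        "Revision:"
    )
```

The model's completion of the revision prompt would then serve as the target response for supervised fine-tuning.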

## Prerequisites 📋✅

In [None]:
!git clone https://github.com/MarinaFuster/cai-implementation
%cd cai-implementation

In [None]:
%pip install -r requirements.txt

In [None]:
import logging

# Configure root logger to display logs in Colab
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO, 
    force=True
)

In [None]:
from dotenv import load_dotenv
# Load the .env file
load_dotenv()
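
Later cells read `SFT_OUTPUT_DIR` and `DPO_OUTPUT_DIR` from the environment, so it can help to check them right after loading the `.env` file. The helper below is a hypothetical sketch, not part of the repository; the repository may require further variables that are not listed here.

```python
import os

# The two variables this notebook reads with os.getenv in later cells.
REQUIRED_ENV_VARS = ("SFT_OUTPUT_DIR", "DPO_OUTPUT_DIR")


def missing_env_vars(names=REQUIRED_ENV_VARS, environ=None):
    """Return the names that are unset or empty in the given environment."""
    env = os.environ if environ is None else environ
    return [name for name in names if not env.get(name)]


missing = missing_env_vars()
if missing:
    print(f"Missing environment variables: {', '.join(missing)}")
```

Catching a missing variable here gives a clearer message than the `TypeError` that `Path(None)` would raise later.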

In [None]:
import os
import sys
# Add the repository root to sys.path so that the `src` package can be imported
sys.path.append(os.path.abspath("."))

## Creating Managers 🛠️👨‍💼

In this section, we initialize the `DatasetManager`, which creates the datasets for both the supervised fine-tuning (SFT) and direct preference optimization (DPO) stages.

In [None]:
from src import DatasetManager
dataset_manager = DatasetManager()

## Supervised Fine-Tuning Stage 🎯📚

In this stage, we fine-tune the pre-trained model on a labeled dataset of prompts paired with target responses. Showing the model the desired output for each input teaches it to produce responses that follow the intended behavior. The dataset mixes harmlessness and helpfulness examples, and training uses LoRA adapters to keep memory usage low.

In [None]:
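import os
from pathlib import Path
# SFT_OUTPUT_DIR must be defined in the .env file loaded above;
# os.environ raises a KeyError naming the variable if it is missing
sft_output_dir = Path(os.environ["SFT_OUTPUT_DIR"])
sft_output_dir.mkdir(exist_ok=True, parents=True)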

In [None]:
from src import ModelManager
model_manager = ModelManager(model_name="mistralai/Mistral-7B-Instruct-v0.3")

In [None]:
sft_dataset = dataset_manager.get_sft_train_dataset(n_samples_harmless=2000, n_samples_helpful=800)
tokenized_sft_dataset = sft_dataset.map(model_manager.tokenize_function, batched=True)
tokenized_sft_dataset = tokenized_sft_dataset.remove_columns(["input_text", "output_text"])
tokenized_sft_dataset.set_format("torch")

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=sft_output_dir,
    per_device_train_batch_size=1,  # Keep batch size low to fit in memory
    gradient_accumulation_steps=4,  # Accumulate gradients over multiple steps
    num_train_epochs=3,  # Adjust based on dataset size
    save_steps=500,
    logging_steps=100,
    save_total_limit=2,
    learning_rate=2e-4,  # Adjust based on performance
    fp16=True,  # Enable mixed precision for speed
    optim="paged_adamw_8bit",  # More memory-efficient optimizer
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    report_to="none"  # Disable wandb integration
)
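
The arguments above imply a specific training schedule, and the arithmetic can be sketched as follows. The helper `training_schedule` is hypothetical; it assumes a single GPU and that the SFT dataset contains exactly the 2000 + 800 requested samples.

```python
import math


def training_schedule(n_samples, per_device_batch_size, grad_accum_steps,
                      num_epochs, warmup_ratio, n_devices=1):
    """Estimate optimizer steps for a run (assumes n_samples divides evenly)."""
    # Samples consumed per optimizer update.
    effective_batch = per_device_batch_size * grad_accum_steps * n_devices
    steps_per_epoch = n_samples // effective_batch
    total_steps = steps_per_epoch * num_epochs
    # The learning rate ramps up linearly over this many steps.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch_size": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


plan = training_schedule(
    n_samples=2000 + 800,  # assumed dataset size
    per_device_batch_size=1,
    grad_accum_steps=4,
    num_epochs=3,
    warmup_ratio=0.03,
)
print(plan)
```

Under these assumptions, `save_steps=500` would produce a handful of checkpoints over the run, of which `save_total_limit=2` keeps only the two most recent.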

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=16, 
    lora_alpha=32, 
    target_modules=["q_proj", "v_proj"],  # Apply LoRA to the attention query and value projections
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

base_model = get_peft_model(model_manager.model, lora_config)
base_model.print_trainable_parameters()
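
As a rough cross-check on the trainable-parameter count, the LoRA arithmetic can be done by hand. The sketch below assumes the published Mistral-7B dimensions (hidden size 4096, 8 key/value heads of dimension 128, 32 layers); these figures are assumptions about the base model, not values read from this notebook.

```python
def lora_param_count(in_features, out_features, rank):
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


# Assumed Mistral-7B architecture values.
HIDDEN_SIZE = 4096
KV_DIM = 8 * 128  # key/value heads x head dimension
N_LAYERS = 32

# Values from the LoraConfig above.
RANK = 16
ALPHA = 32

q_proj_params = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, RANK)
v_proj_params = lora_param_count(HIDDEN_SIZE, KV_DIM, RANK)
total_lora_params = (q_proj_params + v_proj_params) * N_LAYERS

# LoRA scales the adapter update by alpha / rank.
scaling = ALPHA / RANK

print(f"LoRA parameters: {total_lora_params:,} (scaling factor {scaling})")
```

The result can be compared against the figure reported by `print_trainable_parameters()`; a large mismatch would suggest that the target modules or model dimensions differ from these assumptions.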

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=base_model,
    args=training_args,
    train_dataset=tokenized_sft_dataset,
)

trainer.train()
base_model.save_pretrained(sft_output_dir)
model_manager.tokenizer.save_pretrained(sft_output_dir)

## Direct Preference Optimization Stage 🎯🔍

In this stage, we further train the SFT model with Direct Preference Optimization (DPO). DPO learns from preference pairs: for each prompt, a preferred ("chosen") response and a dispreferred ("rejected") one. Instead of fitting a separate reward model and running reinforcement learning, it optimizes the model directly so that the chosen response becomes more likely than the rejected one, relative to a frozen reference model. The original CAI paper uses reinforcement learning from AI feedback for this step; this notebook uses DPO instead. Training on preferences that cover both harmlessness and helpfulness pushes the model toward responses that satisfy both.
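
For intuition, here is a minimal sketch of the standard sigmoid DPO loss for a single preference pair, computed from sequence log-probabilities. The function name and the default `beta=0.1` are illustrative assumptions; the actual training below is handled by TRL's `DPOTrainer`.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) pair."""
    # Implicit rewards: how much more the policy likes each response than the reference does.
    chosen_reward = beta * (policy_chosen_logp - ref_chosen_logp)
    rejected_reward = beta * (policy_rejected_logp - ref_rejected_logp)
    margin = chosen_reward - rejected_reward
    # -log(sigmoid(margin)), computed in a numerically stable way.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Policy identical to the reference: the margin is zero.
baseline = dpo_loss(-12.0, -15.0, -12.0, -15.0)
# Policy has shifted probability toward the chosen response.
improved = dpo_loss(-10.0, -18.0, -12.0, -15.0)
print(baseline, improved)
```

The loss falls as the policy raises the chosen response's likelihood and lowers the rejected one's, both measured against the reference model.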

In [None]:
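import os
from pathlib import Path
# DPO_OUTPUT_DIR must be defined in the .env file loaded above;
# os.environ raises a KeyError naming the variable if it is missing
dpo_output_dir = Path(os.environ["DPO_OUTPUT_DIR"])
dpo_output_dir.mkdir(exist_ok=True, parents=True)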

In [None]:
from src import ModelManager
sft_model_manager = ModelManager(model_dir=sft_output_dir)

In [None]:
dpo_dataset = dataset_manager.get_prefs_train_dataset(n_samples_harmless=2000, n_samples_helpful=800)
tokenized_dpo_dataset = dpo_dataset.map(sft_model_manager.tokenize_function, batched=True)
tokenized_dpo_dataset = tokenized_dpo_dataset.remove_columns(["input_text", "output_text"])
tokenized_dpo_dataset.set_format("torch")

In [None]:
from trl import DPOTrainer, DPOConfig

dpo_training_args = DPOConfig(
    output_dir=dpo_output_dir, 
    logging_steps=10
)

# Initialize the DPOTrainer using the fine-tuned model as the starting point
dpo_trainer = DPOTrainer(
    model=sft_model_manager.model,
    ref_model=None,  # No explicit reference model; DPOTrainer derives one from `model`
    args=dpo_training_args,
    train_dataset=dpo_dataset,  # Raw preference dataset; DPOTrainer tokenizes it internally
)
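
The standard preference format consumed by TRL's `DPOTrainer` has `prompt`, `chosen`, and `rejected` text fields. Whether `get_prefs_train_dataset` already returns that layout is not shown in this notebook, so the check below is a hypothetical sketch for inspecting a few records before training.

```python
REQUIRED_COLUMNS = ("prompt", "chosen", "rejected")


def invalid_preference_rows(records):
    """Return indices of records that would not form a usable preference pair."""
    bad = []
    for index, record in enumerate(records):
        # Every required field must be present and be a non-empty string.
        fields_ok = all(
            isinstance(record.get(column), str) and record[column].strip()
            for column in REQUIRED_COLUMNS
        )
        # A pair with identical responses carries no preference signal.
        if not fields_ok or record["chosen"] == record["rejected"]:
            bad.append(index)
    return bad


# Hypothetical records, for illustration only.
sample_records = [
    {"prompt": "How do I reset my password?",
     "chosen": "Open the account settings and choose 'Reset password'.",
     "rejected": "I cannot help with that."},
    {"prompt": "Tell me a joke.", "chosen": "", "rejected": "No."},
    {"prompt": "Hi", "chosen": "Hello!", "rejected": "Hello!"},
]
print(invalid_preference_rows(sample_records))
```

Assuming `dpo_dataset` yields dictionaries when iterated, a call such as `invalid_preference_rows(list(dpo_dataset)[:100])` would flag malformed rows before they reach the trainer.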

In [None]:
dpo_trainer.train()
sft_model_manager.model.save_pretrained(dpo_output_dir)
sft_model_manager.tokenizer.save_pretrained(dpo_output_dir)