In [None]:
from unsloth import FastLanguageModel
import torch
import torch.nn.functional as F

In [None]:
# ---- Base-model loading options ---------------------------------------------
# Sequence length handed to the loader (and reused by the trainer cells below).
max_seq_length = 2048
# None = let the loader auto-detect (float16 on Tesla T4 / V100, bfloat16 on Ampere+).
dtype = None
# Load 4-bit quantized weights to reduce memory usage; may be set to False.
load_in_4bit = True

# Catalogue of pre-quantized 4-bit checkpoints, kept for reference only:
# nothing in the visible notebook reads this list.
fourbit_models = [
    # Llama 3.1
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",
    # Mistral Nemo (12b) and Mistral v0.3
    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    # Phi
    "unsloth/Phi-3.5-mini-instruct",
    "unsloth/Phi-3-medium-4k-instruct",
    # Gemma 2
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",
]  # More models at https://huggingface.co/unsloth

# Load the base checkpoint and its tokenizer with the options chosen above.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # token="hf_...",  # supply one for gated models like meta-llama/Llama-2-7b-hf
)

In [None]:
# Wrap the base model with LoRA adapters on every attention and MLP projection.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # adapter rank; any value > 0 works (suggested: 8, 16, 32, 64, 128)
    target_modules=[
        # attention projections
        "q_proj", "k_proj", "v_proj", "o_proj",
        # MLP projections
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,   # any value is supported, but 0 is the optimized path
    bias="none",      # any value is supported, but "none" is the optimized path
    # True or "unsloth"; "unsloth" is advertised as using less VRAM for long context.
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,   # rank-stabilized LoRA is available but disabled here
    loftq_config=None,  # LoftQ is available but disabled here
)

In [None]:
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

In [None]:
from datasets import load_from_disk


# "path" must point at the folder that contains data-00000-of-00001.arrow etc.
dataset = load_from_disk("path")

# Sanity report: size, then column names and feature schema.
# Expected columns: input_ids, attention_mask (and possibly labels if they were added).
print(f"Loaded {len(dataset)} examples")
for caption, detail in (
    ("Columns:", dataset.column_names),
    ("Features:", dataset.features),
):
    print(caption, detail)

In [None]:
# Decode the first example back to text and show at most its first 800 characters.
print("Sample decoded:")
first_example_text = tokenizer.decode(dataset[0]["input_ids"])
print(first_example_text[:800] + "...")

In [None]:
class EvidentialTrainer(SFTTrainer):
    """SFT trainer that replaces cross-entropy with a Dirichlet ("evidential") NLL.

    Logits are mapped to non-negative evidence with softplus, ``alpha = evidence + 1``
    is treated as the Dirichlet concentration, and each target token ``y`` contributes
    ``-log(alpha_y / sum(alpha))``, i.e. the negative log of its expected probability.

    The model must return real logits. Per the error raised below, Unsloth needs
    ``os.environ['UNSLOTH_RETURN_LOGITS'] = '1'`` for that.
    """

    @staticmethod
    def _evidential_nll(logits, labels, num_items_in_batch=None, ignore_index=-100):
        """Return the mean evidential NLL for next-token prediction.

        Args:
            logits: model logits; the last two dims are (sequence, vocab).
            labels: token ids aligned with ``logits``; ``ignore_index`` entries are skipped.
            num_items_in_batch: optional token count to normalise by. When ``None``
                the loss is averaged over the non-ignored tokens of this batch.
            ignore_index: label value excluded from the loss (``F.nll_loss`` default).

        Returns:
            A scalar float32 tensor.
        """
        # Shift so that position t is scored against token t + 1.
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:].to(shift_logits.device)

        keep = shift_labels != ignore_index
        # gather() needs valid indices; ignored positions are zeroed out again below.
        gather_index = shift_labels.masked_fill(~keep, 0).unsqueeze(-1)

        # float32: softplus followed by a vocabulary-wide sum is too lossy in fp16/bf16.
        alpha = F.softplus(shift_logits.float()) + 1.0
        strength = alpha.sum(dim=-1)
        alpha_target = alpha.gather(-1, gather_index).squeeze(-1)

        # -log(alpha_y / S) = log(S) - log(alpha_y). alpha >= 1, so both logs are finite
        # and no epsilon is required; only the target entry is needed, not a full
        # vocabulary-sized log-probability tensor.
        token_nll = (torch.log(strength) - torch.log(alpha_target)).masked_fill(~keep, 0.0)

        if num_items_in_batch is not None:
            # The Trainer supplies the token count of the whole accumulated batch.
            denom = num_items_in_batch
            if torch.is_tensor(denom):
                denom = denom.to(token_nll.device)
        else:
            # clamp avoids 0/0 (NaN) when every label in the batch is ignored.
            denom = keep.sum().clamp(min=1)
        return token_nll.sum() / denom

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on ``inputs`` and return the evidential loss.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise ``loss``.

        Raises:
            ValueError: if ``inputs`` has no ``labels`` or the model returns no logits.
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError(
                "EvidentialTrainer needs 'labels' in the batch; check the dataset "
                "columns and the data collator."
            )

        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Double check logic to prevent silent failures
        if logits is None:
            raise ValueError("Logits are None! Ensure os.environ['UNSLOTH_RETURN_LOGITS']='1' is set at the top of the script.")

        loss = self._evidential_nll(logits, labels, num_items_in_batch)

        return (loss, outputs) if return_outputs else loss
        

In [None]:
# First training round: fit the LoRA adapters with the evidential loss on the
# pre-tokenized dataset loaded above.
trainer = EvidentialTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,                  # already tokenized -> super fast
    dataset_text_field=None,                # IMPORTANT: None because we have input_ids already
    max_seq_length=max_seq_length,          # can even be None now
    packing=False,                          # must be False (we already did the packing/splitting)
    args=SFTConfig(
        # 2 sequences per device x 4 accumulation steps = 8 sequences per optimizer step.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=60,
        learning_rate=2e-4,
        # bf16 is hard-coded on and fp16 off.
        # NOTE(review): the dtype comment in the loading cell says T4/V100 use
        # float16 - confirm the target GPU supports bfloat16.
        fp16=False,
        bf16=True,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        lr_scheduler_type="linear",
        seed=3407,                          # same value as random_state in the LoRA cell
        output_dir="outputs",
        report_to="none",
        dataloader_num_workers=4,
        remove_unused_columns=True,
        dataloader_pin_memory=True,

    ),
)

print("Starting training from pre-processed arrow file...")
trainer.train()


In [None]:
# Step 1: persist the trained adapters on their own, in case the merge goes wrong.
trainer.save_model("temp_adapters")
from peft import PeftModel

# Step 2: fold the adapter weights into the base model and drop the adapter wrappers.
merged_model = model.merge_and_unload()

# Step 3: write the merged model and the tokenizer into the same directory.
save_path = "evidential_model_merged"
for artifact in (merged_model, tokenizer):
    artifact.save_pretrained(save_path)

print(f"Full merged model saved to: {save_path}")

In [None]:
from peft import OFTConfig, TaskType, get_peft_model

# Second-stage adapter: OFT on the attention projections only.
# NOTE(review): how `r` interacts with other OFTConfig fields differs between PEFT
# releases - confirm these arguments against the installed PEFT version.
oft_config = OFTConfig(
    r=8,                           # Rank (similar to LoRA)
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type=TaskType.CAUSAL_LM,
    module_dropout=0.05,           # Helps prevent overfitting
    coft=True,                     # "Constrained" OFT - strictly enforces orthogonality
    eps=6e-5,
)

# Attach OFT to `merged_model` (the plain model returned by merge_and_unload() in the
# previous cell), not to `model`: `model` is still the LoRA PeftModel wrapper, and
# wrapping a PeftModel in a second PeftModel nests the adapter bookkeeping.
# NOTE(review): this goes through plain PEFT rather than FastLanguageModel.get_peft_model;
# verify that Unsloth's patched forward pass actually applies the OFT adapters.
model = get_peft_model(merged_model, oft_config)
model.print_trainable_parameters()

In [None]:
from transformers import Trainer
import torch

class EvidentialOrthogonalTrainer(EvidentialTrainer):
    """EvidentialTrainer plus a penalty on token-to-token similarity.

    The extra term is the Frobenius norm of ``cos_sim(H, H) - I`` where ``H`` is the
    last hidden-state tensor, so off-diagonal (token-to-token) cosine similarities
    are pushed towards zero. The norm is taken over the whole (batch, seq, seq)
    tensor and is not normalised by batch size or sequence length, and padding
    positions are not masked out.

    Attributes:
        ortho_weight: coefficient applied to the penalty before adding it to the loss.
    """

    ortho_weight = 0.1

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the parent's evidential loss plus the weighted orthogonality penalty.

        Returns ``(total_loss, outputs)`` when ``return_outputs`` is true, otherwise
        ``total_loss``.

        Raises:
            ValueError: if the model does not return hidden states.
        """
        # Hidden states are only returned on request; copy so the caller's batch
        # is left untouched.
        inputs = {**inputs, "output_hidden_states": True}

        # Reuse the parent's forward pass and evidential loss instead of the model's
        # built-in cross-entropy, so this class really extends EvidentialTrainer.
        base_loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
        )

        all_hidden = outputs.get("hidden_states")
        if all_hidden is None:
            raise ValueError(
                "Model did not return hidden_states even though "
                "output_hidden_states=True was requested."
            )
        # (Batch, Seq, Dim); float32 keeps the normalisation and the norm stable.
        hidden_states = all_hidden[-1].float()

        # Normalize vectors
        norm_hidden = torch.nn.functional.normalize(hidden_states, p=2, dim=-1)

        # Cosine similarity matrix:
        # (Batch, Seq, Dim) x (Batch, Dim, Seq) -> (Batch, Seq, Seq)
        similarity = torch.bmm(norm_hidden, norm_hidden.transpose(1, 2))

        # Build the identity directly on the right device/dtype; it broadcasts
        # over the batch dimension in the subtraction below.
        identity = torch.eye(
            similarity.size(1), device=similarity.device, dtype=similarity.dtype
        )
        ortho_loss = torch.norm(similarity - identity, p='fro')

        total_loss = base_loss + (self.ortho_weight * ortho_loss)

        return (total_loss, outputs) if return_outputs else total_loss

In [None]:
from transformers import TrainingArguments
from trl import SFTConfig # or use TrainingArguments directly

# Second training round on the OFT-wrapped model, using a lower learning rate.
# Unlike the first round, warmup_steps, weight_decay, lr_scheduler_type and seed
# are not set here.
args = SFTConfig(
    output_dir="models", # New output folder for the 2nd round
    # 2 sequences per device x 4 accumulation steps = 8 sequences per optimizer step.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    max_steps=60,                     # Train for 60 MORE steps
    learning_rate=5e-5,               # Lower LR is better for 2nd round (was 2e-4)
    fp16=False,
    bf16=True,
    logging_steps=1,
    optim="adamw_8bit",
    save_strategy="steps",            # Save checkpoints this time
    save_steps=20,                    # i.e. a checkpoint every 20 steps
    report_to="none",
)

# Re-initialize your custom Trainer
# NOTE(review): this instantiates EvidentialTrainer. EvidentialOrthogonalTrainer is
# defined earlier in the notebook but is not used anywhere in the visible cells -
# confirm which loss is intended for this round.
trainer = EvidentialTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,            # You can use the same dataset or a new one
    dataset_text_field=None,
    max_seq_length=max_seq_length,
    packing=False,
    args=args
)

print("Resuming training on the saved Orthogonal model...")
trainer.train()

In [None]:

# Persist the second-round result: the trainer's model first, then the tokenizer
# and the model config, all into one directory.
save_path = "evidential_model_orthogonal"

trainer.save_model(save_path)
for component in (tokenizer, trainer.model.config):
    component.save_pretrained(save_path)

print(f"Orthogonal adapters saved to: {save_path}")