# Training Walkthrough: Fine-tuning Mistral 7B for Financial Crime Detection

This notebook is a step-by-step guide to fine-tuning the FinCrime-LLM model with QLoRA, i.e. training LoRA adapters on top of a 4-bit quantized Mistral 7B base model.

## Contents
1. Environment Setup
2. Load Training Data
3. Model Configuration
4. LoRA Configuration
5. Training Configuration
6. Start Training
7. Save Model
8. Quick Test

In [None]:
# Check GPU availability.
# PYTORCH_CUDA_ALLOC_CONF is only honoured if it is in the environment before CUDA is
# initialised, and the device queries below can trigger that initialisation, so set it
# here, before importing torch (the setup cell sets the same value again, harmlessly).
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Later cells use bfloat16 (4-bit compute dtype and `bf16=True` training), so surface
# missing bf16 support here instead of deep inside model loading / trainer setup.
BF16_SUPPORTED = bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"bf16 supported: {BF16_SUPPORTED}")
    if not BF16_SUPPORTED:
        print("⚠️ This GPU does not support bf16; the bf16 settings used later will likely fail.")

## 1. Environment Setup

In [None]:
import os
import sys
from pathlib import Path

# Project root = parent of the current working directory.
# NOTE(review): assumes the notebook is launched from a direct subdirectory of the
# repository (e.g. notebooks/) — TODO confirm for your setup.
PROJECT_ROOT = Path().absolute().parent

# Put the project root first on sys.path so project packages are importable.
# Drop any existing copies first so re-running this cell does not stack duplicates.
_root_str = str(PROJECT_ROOT)
sys.path[:] = [entry for entry in sys.path if entry != _root_str]
sys.path.insert(0, _root_str)

# CUDA caching-allocator tuning; only effective if set before CUDA is initialised.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'

print("✅ Environment configured")

In [None]:
# Import training modules
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_from_disk

# Weights & Biases is optional: the training arguments use report_to="none", so W&B is
# not used unless you switch reporting on. A missing install should not stop the notebook.
try:
    import wandb
except ImportError:
    wandb = None

print("✅ Libraries imported")

## 2. Load Training Data

In [None]:
# Load the prepared Alpaca-style dataset (train / validation / test splits).
data_path = PROJECT_ROOT / "data" / "processed" / "sar_dataset_alpaca"

# Fail fast: every later cell needs `dataset`, so silently continuing would only
# resurface as a confusing NameError when the trainer is built.
if not data_path.exists():
    raise FileNotFoundError(
        f"❌ Dataset not found at {data_path}\n"
        "Please run: python data/scripts/prepare_sar_data.py"
    )


def _preview(text, limit=200):
    """Return ``text`` cut to ``limit`` characters, adding "..." only when something was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


dataset = load_from_disk(str(data_path))
print(f"✅ Dataset loaded: {dataset}")
print(f"\nTrain examples: {len(dataset['train'])}")
print(f"Validation examples: {len(dataset['validation'])}")
print(f"Test examples: {len(dataset['test'])}")

# Show the first training example
print("\n" + "=" * 80)
print("Sample Training Example:")
print("=" * 80)
example = dataset['train'][0]
print(f"Instruction: {_preview(example['instruction'])}")
print(f"\nInput: {_preview(example['input'])}")
print(f"\nOutput: {_preview(example['output'])}")

## 3. Model Configuration

In [None]:
# Base checkpoint to fine-tune, and where checkpoints / the final model are written.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
OUTPUT_DIR = PROJECT_ROOT / "models" / "sar-mistral-7b-demo"

# 4-bit quantisation of the frozen base weights (QLoRA-style):
#   NF4 quant type + nested (double) quantisation, with bfloat16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

for _label, _value in (
    ("Model", MODEL_NAME),
    ("Output directory", OUTPUT_DIR),
    ("Quantization", "4-bit NF4"),
):
    print(f"{_label}: {_value}")

In [None]:
# Load the base model, quantised on load with `bnb_config`.
# `trust_remote_code` is intentionally omitted: Mistral has a native transformers
# implementation, and enabling it would let the Hub repository execute arbitrary Python.
print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    use_cache=False,  # no KV cache while training (it conflicts with gradient checkpointing)
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The base tokenizer may ship without a pad token; reuse EOS so batches can be padded.
# NOTE(review): collators that mask pad positions in the labels will then also mask the
# real EOS tokens, so the model may not learn where to stop — consider a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Keep the model config's padding id in sync with the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id

print(f"✅ Model loaded. Total parameters: {model.num_parameters():,}")

## 4. LoRA Configuration

In [None]:
# Prepare the quantised model for k-bit training before attaching adapters
model = prepare_model_for_kbit_training(model)

# LoRA configuration: adapters on all attention and MLP projection layers
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Scaling factor
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply LoRA
model = get_peft_model(model, lora_config)


def _param_count(param):
    """Return the logical number of weights held by ``param``.

    bitsandbytes ``Params4bit`` tensors pack two 4-bit values into each byte of their
    storage dtype, so ``numel()`` under-reports them. Scale them back up, mirroring PEFT's
    own trainable-parameter accounting. Other tensors are counted as-is.
    """
    count = param.numel()
    if param.__class__.__name__ == "Params4bit":
        storage_dtype = getattr(param, "quant_storage", None)
        itemsize = getattr(storage_dtype, "itemsize", 1) if storage_dtype is not None else 1
        count *= 2 * itemsize
    return count


# Print trainable parameters (LoRA weights are trainable; the 4-bit base is frozen)
trainable_params = sum(_param_count(p) for p in model.parameters() if p.requires_grad)
total_params = sum(_param_count(p) for p in model.parameters())
print(f"\n✅ LoRA applied")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
print(f"Total parameters: {total_params:,}")

## 5. Training Configuration

In [None]:
# Training arguments.
# transformers renamed `evaluation_strategy` to `eval_strategy` and later removed the old
# spelling, so pass whichever keyword the installed version actually accepts.
import inspect

_EVAL_STRATEGY_KW = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments.__init__).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir=str(OUTPUT_DIR),
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,  # must be a multiple of eval_steps when load_best_model_at_end=True
    eval_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    bf16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none",  # Set to "wandb" if using W&B
    **{_EVAL_STRATEGY_KW: "steps"},
)

print("Training configuration:")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")

In [None]:
# Initialize trainer
def formatting_func(example):
    """Render Alpaca-style record(s) into the instruction / input / response prompt.

    Depending on the TRL version and on ``packing``, ``SFTTrainer`` calls this either with
    a *batch*, where each field is a list of strings and a list of prompts is expected back,
    or with a *single* example, where each field is a string. Both forms are handled.

    Args:
        example: Mapping with ``instruction``, ``input`` and ``output`` keys whose values
            are either strings (one example) or equal-length lists of strings (a batch).

    Returns:
        One formatted prompt string, or a list of them for batched input.

    NOTE(review): no EOS token is appended, so the model may not learn where a response
    ends. Consider appending ``tokenizer.eos_token``, together with a pad token distinct
    from EOS.
    """
    def _render(instruction, input_text, output):
        # Same template as the hand-written prompt in the quick-test cell.
        return (
            f"### Instruction:\n{instruction}\n\n"
            f"### Input:\n{input_text}\n\n"
            f"### Response:\n{output}"
        )

    if isinstance(example["instruction"], (list, tuple)):
        return [
            _render(ins, inp, out)
            for ins, inp, out in zip(example["instruction"], example["input"], example["output"])
        ]
    return _render(example["instruction"], example["input"], example["output"])

# Wire the PEFT model, the train/validation splits and the prompt template into TRL's
# supervised fine-tuning trainer.
# NOTE(review): `tokenizer=` and `max_seq_length=` are the older SFTTrainer keywords;
# recent TRL releases moved them elsewhere — pin the trl version accordingly.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=formatting_func,
    max_seq_length=2048,  # max tokens per training sequence
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

print("✅ Trainer initialized")

## 6. Start Training

**Note**: Training can take several hours depending on GPU and dataset size.

In [None]:
# Run fine-tuning (the long-running step), framed by banner lines in the output.
_rule = "=" * 80

print(f"\n{_rule}\nStarting training...\n{_rule}\n")
trainer.train()
print(f"\n{_rule}\n✅ Training complete!\n{_rule}")

## 7. Save Model

In [None]:
# Persist the trained model and its tokenizer side by side under <OUTPUT_DIR>/final.
# NOTE(review): `model` is a PEFT model at this point, so this presumably writes the LoRA
# adapter rather than merged full weights — confirm the loader below expects that.
final_output_dir = OUTPUT_DIR / "final"
for _saver in (trainer.save_model, tokenizer.save_pretrained):
    _saver(str(final_output_dir))

print(f"✅ Model saved to {final_output_dir}")
print("\nModel can be loaded with:")
print("  from inference.load_model import load_fincrime_model")
print(f"  model, tokenizer = load_fincrime_model('{final_output_dir}')")

## 8. Quick Test

In [None]:
# Test the model on one hand-written prompt in the same template used for training.
test_prompt = """### Instruction:
Generate a suspicious activity report based on the following transaction details:

### Input:
Country: Ghana
Subject: Test Company Ltd
Total Amount: 100,000 GHS
Transactions: Multiple cash deposits under threshold

### Response:
"""

# The model may still be in training mode after trainer.train(); switch to eval so LoRA
# dropout is disabled during generation.
model.eval()

inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,  # explicit pad id for open-ended generation
        use_cache=True,  # the model was loaded with use_cache=False for training
    )

# Decode only the newly generated tokens. Slicing the decoded string by len(test_prompt)
# breaks whenever detokenisation does not reproduce the prompt text character-for-character.
prompt_len = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
print("\n" + "="*80)
print("Test Generation:")
print("="*80)
print(generated_text)

## Next Steps

1. Proceed to notebook 03 for comprehensive model evaluation
2. Evaluate on the held-out `test` split (loaded above but not used during training)
3. Calculate ROUGE, BLEU, and other metrics
4. Deploy the model via API