# Training Mathematics LLM with LoRA Fine-tuning

This notebook fine-tunes SmolLM2-1.7B-Instruct on mathematical derivative problems using LoRA adapters on top of an 8-bit quantized base model.

In [None]:
%%capture
!pip install unsloth
!pip install --upgrade --force-reinstall --no-deps "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install datasets transformers accelerate bitsandbytes torch sympy antlr4-python3-runtime wandb

In [None]:
import os
import re
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import sympy
import wandb
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application
from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    BitsAndBytesConfig
)
from huggingface_hub import notebook_login

# Logging: INFO and above, each record prefixed with a timestamp and its level
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# SymPy parsing: `x` is the symbol handed to parse_expr via local_dict, and the
# transformation tuple is the standard set extended with implicit
# multiplication/application.
x = sympy.Symbol('x')
transformations = (*standard_transformations, implicit_multiplication_application)

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# The training configuration further down sets push_to_hub=True, so the run
# uploads to the Hub; presumably the token entered here needs write access
# to the target repository (confirm against your account settings).
notebook_login()

In [None]:
# Base checkpoint to fine-tune, the Hub repo id used when pushing the result,
# and the token length every example is padded/truncated to.
MODEL = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
SAVED_MODEL = "Joash2024/Math-SmolLM2-1.7B"
max_seq_length = 2048

# Report the current CUDA device when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = f'cuda:{torch.cuda.current_device()}'
else:
    device = 'cpu'
logger.info(f"Using device: {device}")

# Single source of truth for the Trainer, LoRA and wandb configuration
HYPERPARAMS = dict(
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    lora_rank=12,
    lora_alpha=16,
    lora_dropout=0.05,
)

In [None]:
def find_matching_brace(expr: str, start: int) -> int:
    """Locate the ``}`` that closes the brace group opened at index ``start``.

    The scan begins at ``start + 1`` with a nesting depth of one, so inner
    ``{...}`` pairs are skipped over. The character at ``start`` itself is
    never inspected; callers are expected to pass the index of a ``{``.

    Returns:
        Index of the matching closing brace, or -1 if the string ends
        before the group is closed.
    """
    depth = 1
    cursor = start + 1
    limit = len(expr)
    while cursor < limit:
        ch = expr[cursor]
        if ch == '}':
            depth -= 1
            if depth == 0:
                return cursor
        elif ch == '{':
            depth += 1
        cursor += 1
    return -1

def replace_powers(expr: str) -> str:
    """Rewrite LaTeX-style ``^`` exponents as Python ``**`` powers.

    Three forms are handled, in this order:

    1. caret followed by digits:        ``x^2``     -> ``x**2``
    2. caret followed by a brace group: ``x^{2 y}`` -> ``x**(2 y)``
       (groups are located with ``find_matching_brace``; nested groups are
       rewritten outermost first)
    3. any caret still remaining:       ``e^x``     -> ``e**x``

    Step 3 matters because a bare ``^`` is the XOR operator in Python, so an
    expression such as ``e^x`` would otherwise not be treated as a power by
    the parser that consumes this output.

    Args:
        expr: Expression text, possibly containing LaTeX exponents.

    Returns:
        The expression with every ``^`` replaced by ``**``. If a ``^{`` has
        no matching ``}``, brace processing stops at that point and the
        brace itself is left in place.
    """
    # Plain numeric exponents need no grouping
    expr = re.sub(r'\^([0-9]+)', r'**\1', expr)

    # Braced exponents: turn ^{...} into **(...), one group per iteration
    pos = expr.find('^{')
    while pos != -1:
        # pos points at '^', so the opening brace sits at pos + 1
        close_pos = find_matching_brace(expr, pos + 1)
        if close_pos == -1:
            # Unbalanced group: stop instead of looping forever
            break
        exponent = expr[pos + 2:close_pos]
        expr = expr[:pos] + f'**({exponent})' + expr[close_pos + 1:]
        pos = expr.find('^{')

    # Whatever is left (single-token exponents such as ^x or ^(n+1))
    return expr.replace('^', '**')

def parse_latex_fraction(expr: str) -> str:
    """Rewrite every ``\\frac{num}{den}`` as ``((num)/(den))``.

    For each ``\\frac`` (leftmost first) the next ``{`` after it is taken as
    the start of the numerator and the next ``{`` after the numerator's
    closing brace as the start of the denominator; both groups are closed
    using ``find_matching_brace``. Processing stops, leaving the remainder
    untouched, as soon as either group cannot be found or closed.
    """
    marker = '\\frac'
    frac_at = expr.find(marker)
    while frac_at != -1:
        # Numerator group
        num_open = expr.find('{', frac_at)
        if num_open == -1:
            break
        num_close = find_matching_brace(expr, num_open)
        if num_close == -1:
            break

        # Denominator group
        den_open = expr.find('{', num_close)
        if den_open == -1:
            break
        den_close = find_matching_brace(expr, den_open)
        if den_close == -1:
            break

        top = expr[num_open + 1:num_close]
        bottom = expr[den_open + 1:den_close]

        # Splice the parenthesised division in place of the whole \frac{..}{..}
        expr = f"{expr[:frac_at]}(({top})/({bottom})){expr[den_close + 1:]}"
        frac_at = expr.find(marker)

    return expr

def parse_function_power(expr: str, func_name: str, power: str, argument: str) -> str:
    """Build the Python form of a function raised to a power, e.g. sin^2(x).

    Args:
        expr: The full expression being processed; not used by this function.
        func_name: Name of the function (``'sin'``, ``'cos'``, ...).
        power: Exponent text, without the surrounding ``^{`` and ``}``.
        argument: Text of the function's argument.

    Returns:
        ``exp(argument)`` when ``func_name`` is ``'e'`` (the power is dropped
        in that case), otherwise ``(func_name(argument))**(power)``.
    """
    if func_name != 'e':
        return f"({func_name}({argument}))**({power})"
    # 'e' is treated as the exponential function applied to the argument
    return f"exp({argument})"

def parse_latex_functions(expr: str) -> str:
    """Turn LaTeX function notation into Python/SymPy call syntax.

    Two passes are made:

    1. The commands ``\\sin``, ``\\cos``, ``\\tan``, ``\\exp``, ``\\log``,
       ``\\sqrt`` and ``\\pi`` lose their backslash (plain substring
       replacement).
    2. Every ``name{arg}`` or ``name^{power}{arg}`` whose groups contain no
       nested braces is rewritten, leftmost first, until none remain:
       ``name(arg)`` without a power, or the result of
       ``parse_function_power`` with one. Because inner groups become
       parentheses first, nested calls resolve from the inside out.
    """
    # Pass 1: strip the backslash from the supported commands
    for latex_name, plain_name in (
        ('\\sin', 'sin'),
        ('\\cos', 'cos'),
        ('\\tan', 'tan'),
        ('\\exp', 'exp'),
        ('\\log', 'log'),
        ('\\sqrt', 'sqrt'),
        ('\\pi', 'pi'),
    ):
        expr = expr.replace(latex_name, plain_name)

    # Pass 2: lowercase name, optional ^{power}, then a flat {argument}
    call_pattern = re.compile(r'([a-z]+)(\^{[^{}]+})?\{([^{}]+)\}')
    found = call_pattern.search(expr)
    while found is not None:
        name, raw_power, arg = found.groups()
        if raw_power:
            # raw_power looks like "^{...}"; keep only the inner text
            rewritten = parse_function_power(expr, name, raw_power[2:-1], arg)
        else:
            rewritten = f"{name}({arg})"
        expr = expr[:found.start()] + rewritten + expr[found.end():]
        found = call_pattern.search(expr)

    return expr

def clean_latex(expr: str, debug: bool = True) -> Tuple[str, str, List[str]]:
    """Convert a LaTeX expression to parser-friendly text, recording each stage.

    Stages, in order: power replacement, fraction parsing, function parsing,
    removal of leftover LaTeX markup (``\\left``/``\\right`` and braces), and
    insertion of explicit ``*`` in a few implicit-multiplication spots.

    Args:
        expr: The LaTeX source text.
        debug: Accepted for call-site symmetry; this function does not read it.

    Returns:
        ``(original, cleaned, steps)`` where ``steps`` is a list of
        human-readable strings showing the text after each stage.
    """
    original = expr
    cleaned = expr
    steps = [f"Original: {expr}"]

    # Structural rewrites, applied strictly in this order
    for transform, label in (
        (replace_powers, "After power replacement"),
        (parse_latex_fraction, "After fraction parsing"),
        (parse_latex_functions, "After function parsing"),
    ):
        cleaned = transform(cleaned)
        steps.append(f"{label}: {cleaned}")

    # Leftover markup: drop sizing commands, then turn every brace into a paren
    for artifact, substitute in (
        ('\\left', ''),
        ('\\right', ''),
        ('\\{', '('),
        ('\\}', ')'),
        ('{', '('),
        ('}', ')'),
    ):
        cleaned = cleaned.replace(artifact, substitute)
    steps.append(f"After cleanup: {cleaned}")

    # Explicit multiplication: digit-letter, ")(" and "token <space> f("
    for pattern, replacement in (
        (r'([0-9])([a-zA-Z])', r'\1*\2'),
        (r'\)\(', r')*('),
        (r'([0-9a-zA-Z])\s+([a-zA-Z]\()', r'\1*\2'),
    ):
        cleaned = re.sub(pattern, replacement, cleaned)
    steps.append(f"Final: {cleaned}")

    return original, cleaned, steps

def validate_math_expression(expr: str, debug: bool = True) -> bool:
    """Report whether ``expr`` can be cleaned and parsed by SymPy.

    The text is run through ``clean_latex`` and the result handed to
    ``parse_expr``. "Valid" here only means that no exception was raised
    along the way; the parsed value is not checked further.

    Args:
        expr: Candidate expression. Empty or non-string input is rejected.
        debug: When true, print the transformation steps and the parse
            result or failure details.

    Returns:
        True if parsing completed without an exception, False otherwise.
    """
    try:
        if not (expr and isinstance(expr, str)):
            if debug:
                print("Invalid input: empty or not a string")
            return False

        _original, cleaned, steps = clean_latex(expr, debug)

        if debug:
            print("\nTransformation steps:")
            print("\n".join(steps))

        # `x` is supplied explicitly so it resolves to the shared symbol
        parsed = parse_expr(
            cleaned,
            transformations=transformations,
            local_dict={'x': x},
        )

        if debug:
            print(f"Successfully parsed as: {parsed}")
        return True

    except Exception as err:
        # Any failure simply marks the expression as invalid
        if debug:
            print(f"Failed to parse expression:")
            print(f"Error type: {type(err).__name__}")
            print(f"Error message: {str(err)}")
        return False

def is_valid_derivative_problem(example: Dict, debug: bool = False) -> bool:
    """Decide whether a dataset row is a usable function/derivative pair.

    A row passes when its ``'Function'`` and ``'Derivative'`` entries are
    both non-blank strings and both are accepted by
    ``validate_math_expression``. Both expressions are always validated,
    even if the first one fails.

    Args:
        example: Mapping with ``'Function'`` and ``'Derivative'`` keys.
        debug: When true, print why a row was rejected and forward the flag
            to the expression validator.

    Returns:
        True if both expressions validate, False otherwise.
    """
    if not (
        isinstance(example.get('Function'), str)
        and isinstance(example.get('Derivative'), str)
    ):
        if debug:
            print(f"Invalid types in example: {example}")
        return False

    func = example['Function'].strip()
    deriv = example['Derivative'].strip()

    if not (func and deriv):
        if debug:
            print(f"Empty function or derivative in example: {example}")
        return False

    if debug:
        print(f"\nChecking derivative problem:")
        print(f"Function: {func}")
        print(f"Derivative: {deriv}")

    # Build the full list first so neither validation is short-circuited
    outcomes = [validate_math_expression(text, debug) for text in (func, deriv)]
    return all(outcomes)

In [None]:
# Test expressions
# Smoke test for the LaTeX cleaning/validation helpers: each sample below is
# passed to validate_math_expression with debug=True, which prints the
# intermediate transformation steps followed by either the parsed result or
# the error type and message. The boolean return value is not checked here,
# so the printed output has to be read by hand.
test_expressions = [
    "x^2 + 3x",  # Simple polynomial
    "\\frac{1}{2}x^2",  # Fraction with power
    "\\sin{\\left(x^2\\right)}",  # Trigonometric function
    "\\frac{5 x e^{- 4 x}}{8}",  # Complex fraction with e
    "- \\sin{\\left(7 x^{3} \\right)}",  # Negative trig function
    "\\cos^{2}{\\left(x\\right)}",  # Function with power
    "e^{x} \\sin{\\left(x^2\\right)}",  # Multiple functions
    "\\frac{\\sin{\\left(x\\right)}}{\\cos{\\left(x\\right)}}",  # Fraction of functions
    "2x + \\frac{1}{x}",  # Mixed terms
    "x^{2} e^{x} \\sin{\\left(x\\right)}"  # Multiple terms
]

print("Testing expression validation:")
for expr in test_expressions:
    print(f"\nTesting: {expr}")
    validate_math_expression(expr, debug=True)

In [None]:
# Configure 8-bit quantization.
# NOTE: quant type, compute dtype and double quantization are 4-bit-only
# settings in BitsAndBytesConfig (spelled `bnb_4bit_*`). The `bnb_8bit_*`
# keyword arguments used here previously are not parameters of the class and
# were ignored, so removing them does not change how the model is loaded.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Load tokenizer and model
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Padding reuses the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16
)

In [None]:
# Load the 'derivatives' configuration and carve the last 2000 rows of its
# train split off as a validation set; each part is shuffled independently.
logger.info("Loading datasets...")
train_data = load_dataset(
    "Alexis-Az/math_datasets",
    name='derivatives',
    split='train[:-2000]',
).shuffle()
eval_data = load_dataset(
    "Alexis-Az/math_datasets",
    name='derivatives',
    split="train[-2000:]",
).shuffle()

logger.info(f"Training examples: {len(train_data)}")
logger.info(f"Validation examples: {len(eval_data)}")

# Show one row and the column layout for a quick visual check
print("\nSample training data:")
print(train_data[0])
print("\nDataset columns:")
print(train_data.column_names)

In [None]:
def format_math_problem(function: str, derivative: str) -> str:
    """Render one function/derivative pair as the training prompt text.

    The text states the task, gives the function and its derivative, and
    ends with a fixed three-step "verification" that repeats both. Leading
    and trailing whitespace of the assembled text is stripped.
    """
    lines = [
        "Given a mathematical function, find its derivative.",
        "",
        f"Function: {function}",
        f"The derivative of this function is: {derivative}",
        "",
        "Let's verify this step by step:",
        f"1. Starting with f(x) = {function}",
        "2. Applying differentiation rules",
        f"3. We get f'(x) = {derivative}",
    ]
    return "\n".join(lines).strip()

def prepare_training_data(examples: Dict) -> Dict:
    """Turn a batch of rows into tokenized causal-LM training inputs.

    Args:
        examples: Batch with parallel ``'Function'`` and ``'Derivative'``
            sequences (the form ``datasets.map(batched=True)`` passes in).

    Returns:
        The tokenizer output (PyTorch tensors padded/truncated to
        ``max_seq_length``) with an added ``'labels'`` entry. Labels equal
        ``input_ids`` except at padded positions, which are set to -100 so
        they are excluded from the loss.
    """
    # Format all examples in batch
    texts = [
        format_math_problem(function=func, derivative=deriv)
        for func, deriv in zip(examples['Function'], examples['Derivative'])
    ]

    # Tokenize with padding and truncation
    tokenized = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=max_seq_length,
        return_tensors='pt'  # Return PyTorch tensors directly
    )

    # Labels start as a copy of input_ids (causal LM objective)
    labels = tokenized['input_ids'].clone()

    # Every row is padded to max_seq_length, and the pad token is the EOS
    # token, so padding cannot be recognised by token id. Use the attention
    # mask instead and mark padded positions with -100, the value the loss
    # ignores; otherwise most of the training signal would come from padding.
    if 'attention_mask' in tokenized:
        labels[tokenized['attention_mask'] == 0] = -100

    tokenized['labels'] = labels

    return tokenized

# Dry run on four rows so formatting/tokenization problems surface early
logger.info("Testing data processing on a small batch...")
test_batch = train_data.select(range(4))
processed_batch = prepare_training_data(test_batch)

print("\nVerifying test batch:")
for field_name, tensor in processed_batch.items():
    print(f"{field_name}: {type(tensor)}, shape: {tensor.shape}, dtype: {tensor.dtype}")

# Tokenize both splits in batches of 100, keeping only the tokenizer outputs
logger.info("\nProcessing full datasets...")
train_dataset = train_data.map(
    prepare_training_data,
    remove_columns=train_data.column_names,
    batched=True,
    batch_size=100,
    desc="Processing training data",
)

eval_dataset = eval_data.map(
    prepare_training_data,
    remove_columns=eval_data.column_names,
    batched=True,
    batch_size=100,
    desc="Processing validation data",
)

# Report the sizes of the processed splits
print("\nVerifying final dataset:")
print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(eval_dataset)}")

# Inspect the first processed row field by field
sample = train_dataset[0]
print("\nSample features:")
for field_name, field_value in sample.items():
    if not isinstance(field_value, torch.Tensor):
        print(f"{field_name}: type {type(field_value)}")
        continue
    print(f"{field_name}: tensor shape {field_value.shape}, dtype {field_value.dtype}")

# Round-trip the token ids back to text as a final sanity check
print("\nSample decoded text:")
decoded = tokenizer.decode(sample['input_ids'])
print(decoded)

In [None]:
# LoRA adapter settings; rank, alpha and dropout come from HYPERPARAMS
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=HYPERPARAMS['lora_rank'],
    lora_alpha=HYPERPARAMS['lora_alpha'],
    lora_dropout=HYPERPARAMS['lora_dropout'],
    bias="none",
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# Prepare the quantized base model for training, wrap it with the adapters
# and print how many parameters are trainable
logger.info("Preparing model for training...")
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Start a Weights & Biases run; the run name encodes the LoRA rank and the
# learning rate, and the full hyperparameter dict is stored as run config.
wandb.init(
    name=f"lora-r{HYPERPARAMS['lora_rank']}-lr{HYPERPARAMS['learning_rate']}",
    config=HYPERPARAMS,
    project="math-llm",
)

In [None]:
# Trainer configuration, grouped by concern. Tunable values come from
# HYPERPARAMS; everything else is fixed here.
training_args = TrainingArguments(
    # --- run length and batch geometry ---
    num_train_epochs=HYPERPARAMS['num_train_epochs'],
    max_steps=-1,
    per_device_train_batch_size=HYPERPARAMS['per_device_train_batch_size'],
    per_device_eval_batch_size=HYPERPARAMS['per_device_eval_batch_size'],
    gradient_accumulation_steps=HYPERPARAMS['gradient_accumulation_steps'],
    group_by_length=True,
    # --- optimisation ---
    learning_rate=HYPERPARAMS['learning_rate'],
    weight_decay=HYPERPARAMS['weight_decay'],
    max_grad_norm=HYPERPARAMS['max_grad_norm'],
    warmup_ratio=HYPERPARAMS['warmup_ratio'],
    lr_scheduler_type="cosine",
    # --- precision and memory ---
    fp16=True,
    bf16=False,
    gradient_checkpointing=True,
    # --- evaluation, logging and checkpoints (eval and save share a 100-step cadence) ---
    evaluation_strategy="steps",
    eval_steps=100,
    logging_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=3,
    load_best_model_at_end=True,
    report_to=["wandb"],
    # --- outputs ---
    output_dir="./results",
    push_to_hub=True,
    hub_model_id=SAVED_MODEL,
    hub_strategy="every_save",
    # --- dataset handling ---
    remove_unused_columns=False,
)

In [None]:
# Wire the adapter-wrapped model, the arguments and both processed splits
# into a Trainer
trainer = Trainer(
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
    args=training_args,
    model=model,
)

# Run the fine-tuning loop
logger.info("Starting training...")
train_result = trainer.train()

# Echo every metric returned by the training run into the log
logger.info("\nTraining results:")
for metric_name, metric_value in train_result.metrics.items():
    logger.info(f"{metric_name}: {metric_value}")

In [None]:
# Save the model
# Final upload through the trainer. The training arguments above also set
# push_to_hub=True with hub_strategy="every_save", so this call comes on top
# of whatever was pushed during training.
logger.info("\nSaving model to HuggingFace Hub...")
trainer.push_to_hub()

# SAVED_MODEL is the repo id that was passed as hub_model_id
logger.info(f"\nTraining complete! Model saved to: {SAVED_MODEL}")
logger.info(f"You can access your model at: https://huggingface.co/{SAVED_MODEL}")

# Close wandb run
wandb.finish()