# Qwen Model Training with MLX Framework

**Purpose**: Fine-tune the Qwen 2.5 14B model with the MLX framework for logical fallacy detection

**Dataset**: FLICC (Fallacy detection dataset with train/val/test splits)

**Method**: LoRA (Low-Rank Adaptation) fine-tuning with Technocognitive Adaptation

---

## Configuration Overview

- **Model**: Qwen/Qwen2.5-14B-Instruct
- **LoRA Rank**: 16 (higher than the common default of 8, to offset LoRA's reduced capacity noted in the paper)
- **LoRA Alpha**: 32 (2× rank, a common stability heuristic)
- **Batch Size**: 1 (chosen to fit in 32GB RAM)
- **Gradient Accumulation**: 16 (intended effective batch size of 16)
- **Learning Rate**: 1e-5 (reported as optimal in the paper)
- **Iterations**: 1000 (~6 epochs over ~2,500 samples, assuming 16 examples per iteration: 1000 × 16 ÷ 2500 ≈ 6.4)
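The "~6 epochs" figure is plain arithmetic, assuming each iteration is one optimizer update that consumes `batch_size × gradient_accumulation` examples. It only holds if the accumulation steps are actually applied; with a plain batch size of 1, the same 1000 iterations cover well under one epoch. A quick check (`approx_epochs` is a hypothetical helper, not part of MLX-LM):

```python
def approx_epochs(iters: int, batch_size: int, grad_accum_steps: int, n_train: int) -> float:
    """Passes over the training set, assuming each iteration consumes
    batch_size * grad_accum_steps examples."""
    return iters * batch_size * grad_accum_steps / n_train

print(approx_epochs(1000, 1, 16, 2500))  # 6.4: the intended schedule
print(approx_epochs(1000, 1, 1, 2500))   # 0.4: batch size 1 with no accumulation
```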

## Step 1: Import Required Libraries

In [None]:
# Core libraries
import json
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Optional

# MLX framework
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx_lm import load, generate
from mlx_lm import lora

# Data processing
import pandas as pd
import numpy as np

# Utilities
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix

print("All libraries imported successfully!")

## Step 2: Training Configuration

### Technocognitive Adaptation Parameters

Based on research findings:
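- **Rank 16**: A higher rank than the common default of 8, to offset the reduced capacity of LoRA noted in the paper
- **Alpha 32**: Common stability heuristic (Alpha = 2 × Rank)
- **Learning Rate 1e-5**: Rate reported as optimal in the paper
- **Scale 10.0**: Multiplier applied to the LoRA update. This is a different convention from Alpha: PEFT-style implementations scale the update by Alpha / Rank (here 32 / 16 = 2.0), whereas MLX-LM's `lora_parameters` take a `scale` value, so the two settings are not interchangeable
- **Comprehensive Layer Coverage**: All attention (q, k, v, o) and MLP (gate, up, down) projections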
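To make Rank, Alpha and Scale concrete, here is a NumPy sketch of a LoRA-adapted linear layer. It assumes the common formulation y = x W^T + s · (x A) B, with a frozen weight W, a trainable down-projection A of shape (d_in, r) and a zero-initialised up-projection B of shape (r, d_out). Under the PEFT convention s = Alpha / Rank, while MLX-LM's `lora_parameters` take `scale` as the multiplier. Names and shapes are illustrative, not MLX-LM's internals.

```python
import numpy as np

def lora_forward(x, W, A, B, scale):
    """Frozen base projection plus a scaled low-rank update: x W^T + scale * (x A) B."""
    return x @ W.T + scale * ((x @ A) @ B)

def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters added to one d_in -> d_out projection (A plus B)."""
    return rank * (d_in + d_out)

def peft_scale(alpha: float, rank: int) -> float:
    """PEFT-style multiplier on the low-rank update."""
    return alpha / rank

rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 32, 16
x = rng.standard_normal((4, d_in))
W = rng.standard_normal((d_out, d_in))        # frozen pretrained weight
A = rng.standard_normal((d_in, rank)) * 0.01  # trainable down-projection
B = np.zeros((rank, d_out))                   # trainable up-projection, zero at init
y = lora_forward(x, W, A, B, scale=peft_scale(32, 16))

# With B = 0 the adapted layer starts out identical to the frozen one
print(np.allclose(y, x @ W.T))              # True
print(lora_param_count(d_in, d_out, rank))  # 1536
```

Because B starts at zero, training begins from the base model's behaviour, and the multiplier (2.0 under Alpha / Rank, 10.0 as MLX-LM's `scale` here) only controls how strongly the learned update is applied.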

In [None]:
@dataclass
class TrainingConfig:
    """
    Training configuration based on Technocognitive Adaptation research.
    Optimized for 32GB RAM with Qwen 2.5 14B model.
    """
    
    # Base Model
    model_name: str = "Qwen/Qwen2.5-14B-Instruct"
    
    # LoRA Configuration (Technocognitive Adaptation)
    lora_rank: int = 16              # Higher rank to compensate for LoRA weakness
    lora_alpha: int = 32             # Alpha = 2 * Rank (stability rule)
    lora_dropout: float = 0.05       # Standard dropout for regularization
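    lora_scale: float = 10.0         # Multiplier on the LoRA update (MLX-LM `scale`)
    
    # LoRA target modules: projection names inside each transformer block.
    # Not the same as mlx_lm's `lora_layers`/`num_layers` (a count of blocks).
    lora_layers: Optional[List[str]] = None    # Set in __post_init__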
    
    # Training Hyperparameters (32GB RAM optimized)
    batch_size: int = 1                      # Chosen to fit in 32GB RAM
    gradient_accumulation_steps: int = 16    # Intended effective batch size = 16
    learning_rate: float = 1.0e-5            # Paper-validated optimal rate
    weight_decay: float = 0.01               # Weight decay for regularization
    
    # Training Duration
    # Dataset ~2500 rows: 1000 iterations x 16 examples ≈ 6.4 epochs,
    # but only if all 16 accumulation steps are actually applied by the trainer
    iters: int = 1000                # Total training iterations
    steps_per_eval: int = 100        # Validation frequency
    save_every: int = 100            # Checkpoint frequency
    steps_per_report: int = 10       # Loss reporting frequency
    
    # Data Paths
    train_data_path: str = "Data/fallacy_train.csv"
    val_data_path: str = "Data/fallacy_val.csv"
    test_data_path: str = "Data/fallacy_test.csv"
    
    # Output Configuration
    output_dir: str = "./output/qwen14b_flicc"
    data_dir: str = "./data"
    
    # Training Options
    max_seq_length: int = 512        # Maximum sequence length
    seed: int = 42                   # Random seed for reproducibility
    grad_checkpoint: bool = True     # Gradient checkpointing (saves memory)
    val_batches: int = 25            # Validation batches per evaluation
    
    def __post_init__(self):
        """Initialize LoRA layers after dataclass creation."""
        if self.lora_layers is None:
            # Comprehensive layer coverage as specified
            self.lora_layers = [
                "q_proj",      # Query projection
                "v_proj",      # Value projection
                "k_proj",      # Key projection
                "o_proj",      # Output projection
                "gate_proj",   # Gate projection (MLP)
                "up_proj",     # Up projection (MLP)
                "down_proj"    # Down projection (MLP)
            ]


# Initialize configuration
config = TrainingConfig()

# Display configuration
print("="*70)
print("TRAINING CONFIGURATION (Technocognitive Adaptation)")
print("="*70)
print(f"\nModel: {config.model_name}")

print(f"\nLoRA Settings (Technocognitive Adaptation):")
print(f"  Rank: {config.lora_rank} (higher to compensate for LoRA weakness)")
print(f"  Alpha: {config.lora_alpha} (2 × Rank stability rule)")
print(f"  Dropout: {config.lora_dropout}")
print(f"  Scale: {config.lora_scale}")
print(f"  Target Layers: {', '.join(config.lora_layers)}")

print(f"\nTraining Settings (32GB RAM Optimized):")
print(f"  Batch Size: {config.batch_size}")
print(f"  Gradient Accumulation: {config.gradient_accumulation_steps}")
print(f"  Effective Batch Size: {config.batch_size * config.gradient_accumulation_steps}")
print(f"  Learning Rate: {config.learning_rate} (paper-optimal)")
print(f"  Weight Decay: {config.weight_decay}")

print(f"\nTraining Duration:")
print(f"  Total Iterations: {config.iters}")
print(f"  Approximate Epochs: ~6 over ~2500 samples, only if {config.gradient_accumulation_steps} examples are used per iteration")
print(f"  Validation Every: {config.steps_per_eval} steps")
print(f"  Save Checkpoint Every: {config.save_every} steps")

print(f"\nData Splits:")
print(f"  Training: {config.train_data_path}")
print(f"  Validation: {config.val_data_path}")
print(f"  Test: {config.test_data_path}")

print(f"\nOutput Directory: {config.output_dir}")
print("="*70)
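The "32GB RAM optimized" label deserves a back-of-the-envelope check. Holding the weights alone in bf16 takes 2 bytes per parameter, which for a model of roughly 14.7B parameters is close to the entire 32GB before activations, LoRA optimizer state and the operating system are counted. MLX-LM can train LoRA adapters on a quantized model, and `mlx_lm.convert` with its `-q` option produces a 4-bit conversion that shrinks the weights to roughly a quarter. The 4.5 bits/param figure below (4 bits plus per-group scales) is an assumed round number, not a measurement.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory needed just to hold the weights, in decimal gigabytes."""
    return n_params * bits_per_param / 8 / 1e9

QWEN25_14B_PARAMS = 14.7e9  # published parameter count (~14.7B)

bf16_gb = weight_memory_gb(QWEN25_14B_PARAMS, 16)
# Assumed overhead: 4-bit values plus per-group quantization scales ~ 4.5 bits/param
q4_gb = weight_memory_gb(QWEN25_14B_PARAMS, 4.5)
print(f"bf16 weights: ~{bf16_gb:.1f} GB; 4-bit weights: ~{q4_gb:.1f} GB")
# bf16 weights: ~29.4 GB; 4-bit weights: ~8.3 GB
```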

## Step 3: Data Loading and Preparation

Load all three data splits: train, validation, and test.

In [None]:
def load_fallacy_data(file_path: str, split_name: str = "data") -> pd.DataFrame:
    """
    Load fallacy detection dataset from CSV file.
    
    Args:
        file_path: Path to the CSV file
        split_name: Name of the split (train/val/test) for display
        
    Returns:
        DataFrame with 'text' and 'label' columns
    """
    print(f"\nLoading {split_name} data from: {file_path}")
    
    try:
        df = pd.read_csv(file_path)
        print(f"Loaded {len(df):,} examples")
        
        # Display label distribution
        print(f"\nLabel Distribution ({split_name}):")
        label_counts = df['label'].value_counts().sort_index()
        for label, count in label_counts.items():
            print(f"  {label:25s}: {count:4d} ({count/len(df)*100:.1f}%)")
        
        return df
        
    except FileNotFoundError:
        print(f"ERROR: File not found at {file_path}")
        raise
    except Exception as e:
        print(f"ERROR loading data: {e}")
        raise


# Load all three data splits
print("\n" + "="*70)
print("LOADING ALL DATA SPLITS")
print("="*70)

train_df = load_fallacy_data(config.train_data_path, "TRAIN")
val_df = load_fallacy_data(config.val_data_path, "VALIDATION")
test_df = load_fallacy_data(config.test_data_path, "TEST")

print(f"\n" + "="*70)
print("DATA LOADING SUMMARY")
print("="*70)
print(f"Training samples:   {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"Test samples:       {len(test_df):,}")
print(f"Total samples:      {len(train_df) + len(val_df) + len(test_df):,}")
print("="*70)
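Because the test set is meant to give an unbiased estimate, it is worth confirming that no text appears in more than one split. A minimal sketch, assuming the `text` column used above and treating texts that differ only in case or surrounding whitespace as duplicates (`find_cross_split_duplicates` is a hypothetical helper):

```python
from typing import Dict

import pandas as pd

def find_cross_split_duplicates(splits: Dict[str, pd.DataFrame], text_col: str = "text") -> pd.DataFrame:
    """Return rows whose normalised text occurs in more than one split."""
    frames = [
        pd.DataFrame({"norm_text": df[text_col].astype(str).str.strip().str.lower(), "split": name})
        for name, df in splits.items()
    ]
    combined = pd.concat(frames, ignore_index=True)
    # Count distinct splits per normalised text; more than one means leakage
    n_splits = combined.groupby("norm_text")["split"].nunique()
    leaked = n_splits[n_splits > 1].index
    return combined[combined["norm_text"].isin(leaked)].sort_values(["norm_text", "split"])
```

In this notebook it would be called as `find_cross_split_duplicates({"train": train_df, "val": val_df, "test": test_df})`; an empty result means no overlap was found.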

## Step 4: Prompt Engineering

Create prompts in Qwen's chat format for instruction fine-tuning.

**Format**:
```
<|im_start|>system
[System instructions]<|im_end|>
<|im_start|>user
[User query]<|im_end|>
<|im_start|>assistant
[Model response]<|im_end|>
```
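The layout above can be generated from a list of role/content messages, which shows that `prepare_prompt` in the next cell is this layout with a fixed system message. A minimal sketch: `to_chatml` is a hypothetical helper that mirrors the string building in `prepare_prompt`. Qwen's own chat template, available through the tokenizer's `apply_chat_template`, uses the same markers but may differ in details such as inserting a default system message.

```python
from typing import Dict, List

def to_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render role/content messages in ChatML, one <|im_start|>...<|im_end|> block per turn."""
    text = "\n".join(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages)
    if add_generation_prompt:
        # Open an assistant turn so the model writes the answer next
        text += "\n<|im_start|>assistant\n"
    return text

msgs = [{"role": "system", "content": "S"}, {"role": "user", "content": "U"}]
print(to_chatml(msgs, add_generation_prompt=True))
```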

In [None]:
# Define fallacy categories
FALLACY_CATEGORIES = [
    "ad hominem",
    "anecdote",
    "cherry picking",
    "conspiracy theory",
    "fake experts",
    "false choice",
    "false equivalence",
    "impossible expectations",
    "misrepresentation",
    "oversimplification",
    "single cause",
    "slothful induction"
]

def prepare_prompt(text: str, label: Optional[str] = None) -> str:
    """
    Prepare prompt in Qwen chat format.
    
    Args:
        text: Input text to classify
        label: Ground truth label (for training) or None (for inference)
        
    Returns:
        Formatted prompt string
    """
    # System instruction
    system_msg = (
        "You are a logical fallacy detection expert. "
        "Classify the given text into one of these fallacy categories: "
        f"{', '.join(FALLACY_CATEGORIES)}. "
        "Respond with only the fallacy category name."
    )
    
    # User query
    user_msg = f"Text: {text}\n\nWhat type of fallacy is present in this text?"
    
    # Build prompt
    if label is not None:
        # Training format (includes the answer)
        prompt = (
            f"<|im_start|>system\n{system_msg}<|im_end|>\n"
            f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            f"<|im_start|>assistant\n{label}<|im_end|>"
        )
    else:
        # Inference format (no answer)
        prompt = (
            f"<|im_start|>system\n{system_msg}<|im_end|>\n"
            f"<|im_start|>user\n{user_msg}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
    
    return prompt


def prepare_dataset(df: pd.DataFrame, split_name: str = "data") -> List[Dict[str, str]]:
    """
    Convert DataFrame to MLX training format (JSONL).
    
    Args:
        df: DataFrame with 'text' and 'label' columns
        split_name: Name of the split for display
        
    Returns:
        List of dictionaries with 'text' key containing formatted prompts
    """
    print(f"\nPreparing {split_name} dataset: {len(df):,} examples...")
    
    dataset = []
    # enumerate gives a positional counter; the DataFrame index may not run 0..n-1
    for i, (_, row) in enumerate(df.iterrows(), start=1):
        prompt = prepare_prompt(row['text'], row['label'])
        dataset.append({"text": prompt})
        
        # Show progress every 500 examples
        if i % 500 == 0:
            print(f"  Processed {i:,}/{len(df):,} examples...")
    
    print(f"Preparation complete for {split_name}!")
    return dataset


# Example prompt display
print("\n" + "="*70)
print("EXAMPLE PROMPT FORMAT")
print("="*70)
sample_text = train_df.iloc[0]['text']
sample_label = train_df.iloc[0]['label']
example_prompt = prepare_prompt(sample_text, sample_label)
print(example_prompt[:500] + "..." if len(example_prompt) > 500 else example_prompt)
print("="*70)

# Prepare all datasets
print("\n" + "="*70)
print("PREPARING DATASETS")
print("="*70)

train_dataset = prepare_dataset(train_df, "TRAIN")
val_dataset = prepare_dataset(val_df, "VALIDATION")
test_dataset = prepare_dataset(test_df, "TEST")

print(f"\n" + "="*70)
print("DATASET PREPARATION SUMMARY")
print("="*70)
print(f"Training examples:   {len(train_dataset):,}")
print(f"Validation examples: {len(val_dataset):,}")
print(f"Test examples:       {len(test_dataset):,}")
print("="*70)

## Step 5: Save Datasets in JSONL Format

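MLX-LM's LoRA trainer reads `train.jsonl` and `valid.jsonl` (plus `test.jsonl` when testing is enabled) from the directory passed as `data`, with one JSON object per line. We save the train and validation sets here; the test set is held back for the final evaluation in Step 11.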

In [None]:
# Create data directory
data_dir = Path(config.data_dir)
data_dir.mkdir(exist_ok=True)

print("\nSaving datasets to JSONL format...")

# Save training data
train_path = data_dir / "train.jsonl"
with open(train_path, "w", encoding="utf-8") as f:
    for item in train_dataset:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"Saved training data to: {train_path}")

# Save validation data
val_path = data_dir / "valid.jsonl"
with open(val_path, "w", encoding="utf-8") as f:
    for item in val_dataset:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")
print(f"Saved validation data to: {val_path}")

# Note: Test set is kept separate for final evaluation after training
print(f"\nNote: Test set ({len(test_dataset):,} examples) reserved for final evaluation")
print("\nAll datasets saved successfully!")
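With the `{"text": ...}` records written above, MLX-LM computes the training loss over every token of the formatted prompt, so the label is only a handful of the tokens the loss is averaged over. Recent MLX-LM releases also accept chat-style `{"messages": [...]}` records and, with prompt masking enabled (`mask_prompt: true` in a config file or `--mask-prompt` on the CLI), restrict the loss to the final assistant turn; treat that option as an assumption to verify against your installed version. A minimal sketch of writing such records (`to_chat_record` and `write_jsonl` are hypothetical helpers):

```python
import json
from pathlib import Path
from typing import Dict, Iterable, List

def to_chat_record(text: str, label: str, system_msg: str) -> Dict[str, List[Dict[str, str]]]:
    """One training example as a chat transcript; the assistant turn holds only the label."""
    return {"messages": [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": f"Text: {text}\n\nWhat type of fallacy is present in this text?"},
        {"role": "assistant", "content": label},
    ]}

def write_jsonl(records: Iterable[dict], path: Path) -> None:
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```

Usage would mirror the cell above, e.g. `write_jsonl((to_chat_record(t, l, system_msg) for t, l in zip(train_df["text"], train_df["label"])), data_dir / "train.jsonl")`, where `system_msg` is the same instruction string that `prepare_prompt` builds.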

## Step 6: Setup Training Arguments

Configure MLX-LM trainer with Technocognitive Adaptation parameters.

In [None]:
# Prepare training arguments with Technocognitive Adaptation settings
train_args = argparse.Namespace(
    # Model and data
    model=config.model_name,
    data=str(data_dir),
    train=True,
    
    # LoRA configuration (Technocognitive Adaptation)
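    # Number of transformer blocks (counted from the last one) that receive LoRA
    # adapters; newer mlx-lm releases call this `num_layers`. The projection names in
    # config.lora_layers, and rank/alpha/dropout/scale, are NOT passed here.
    lora_layers=16,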
    
    # Training hyperparameters
    batch_size=config.batch_size,
    iters=config.iters,
    val_batches=config.val_batches,
    learning_rate=config.learning_rate,
    
    # Evaluation and reporting
    steps_per_report=config.steps_per_report,
    steps_per_eval=config.steps_per_eval,
    save_every=config.save_every,
    
    # Output and checkpointing
    adapter_path=config.output_dir,
    
    # Advanced options
    grad_checkpoint=config.grad_checkpoint,
    seed=config.seed,
    use_dora=False,
    resume_adapter_file=None,
    
    # Sequence length
    max_seq_length=config.max_seq_length,
    
    # Test settings (disabled during training)
    test=False,
    test_batches=0
)

# Display training plan
print("\n" + "="*70)
print("TRAINING PLAN (Technocognitive Adaptation)")
print("="*70)
print(f"\nData:")
print(f"  Training examples: {len(train_dataset):,}")
print(f"  Validation examples: {len(val_dataset):,}")

print(f"\nBatching:")
print(f"  Batch size: {config.batch_size}")
print(f"  Gradient accumulation: {config.gradient_accumulation_steps} (planned; not passed to the trainer)")
print(f"  Effective batch size as configured: {config.batch_size}")

print(f"\nTraining Schedule:")
print(f"  Total iterations: {config.iters:,}")
samples_seen = config.iters * config.batch_size
print(f"  Total samples processed: ~{samples_seen:,}")
print(f"  Approximate epochs: ~{samples_seen / len(train_dataset):.1f}")

print(f"\nEvaluation & Checkpointing:")
print(f"  Validate every: {config.steps_per_eval} steps")
print(f"  Save checkpoint every: {config.save_every} steps")
print(f"  Report metrics every: {config.steps_per_report} steps")
print(f"  Total checkpoints: ~{config.iters // config.save_every}")

print(f"\nOptimization:")
print(f"  Learning rate: {config.learning_rate} (paper-optimal)")
print(f"  Weight decay: {config.weight_decay}")
print(f"  LoRA rank: {config.lora_rank}")
print(f"  LoRA alpha: {config.lora_alpha}")
print(f"  LoRA scale: {config.lora_scale}")

print("="*70)

## Step 7: Start Training

**Important**: Training the 14B model will take several hours.

**Expected Duration**: 1000 iterations with validation and checkpointing may take 4-8 hours depending on hardware.

Monitor the output for:
- Training loss (should decrease steadily)
- Validation loss (should decrease, watch for overfitting)
- Checkpoints (saved every 100 steps)
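Choosing which checkpoint to keep comes down to the validation losses reported during training. A small helper, hypothetical and fed with values copied by hand from the training log, that picks the iteration with the lowest validation loss and flags a sustained rise, a typical sign of overfitting:

```python
from typing import Dict, Tuple

def best_checkpoint(val_losses: Dict[int, float], patience: int = 2) -> Tuple[int, bool]:
    """Return (iteration with the lowest validation loss, whether the loss rose
    at each of the last `patience` evaluations)."""
    iters = sorted(val_losses)
    best_iter = min(iters, key=lambda it: val_losses[it])
    tail = [val_losses[it] for it in iters[-(patience + 1):]]
    # Overfitting signal: each of the last `patience` evaluations is worse than the one before
    rising = len(tail) == patience + 1 and all(b > a for a, b in zip(tail, tail[1:]))
    return best_iter, rising

# Made-up numbers for illustration, not results from this notebook
print(best_checkpoint({100: 2.0, 200: 1.5, 300: 1.2, 400: 1.3, 500: 1.4}))  # (300, True)
```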

In [None]:
# Create output directory
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

print("\n" + "="*70)
print("STARTING TRAINING")
print("="*70)
print(f"\nModel: {config.model_name}")
print(f"Output: {config.output_dir}")
print(f"\nExpected duration: 4-8 hours (depends on hardware)")
print(f"\nThe model will be saved at:")
print(f"  - Every {config.save_every} steps")
print(f"  - After training completes")
print("\n" + "="*70 + "\n")

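# `lora.train` is mlx_lm's low-level trainer and does not accept just
# (model, tokenizer, args), so train via the mlx_lm.lora CLI with a YAML config.
# Key names follow recent mlx-lm releases (older ones use lora_layers for
# num_layers); check `python -m mlx_lm.lora --help`. The CLI loads the model itself.
import subprocess
import sys
import yaml  # PyYAML

attention_projs = {"q_proj", "k_proj", "v_proj", "o_proj"}
lora_config = {
    "model": config.model_name, "train": True, "fine_tune_type": "lora",
    "data": str(data_dir), "adapter_path": config.output_dir,
    "num_layers": 16,  # the last 16 transformer blocks receive adapters
    "batch_size": config.batch_size, "iters": config.iters,
    "learning_rate": config.learning_rate, "val_batches": config.val_batches,
    "steps_per_report": config.steps_per_report, "steps_per_eval": config.steps_per_eval,
    "save_every": config.save_every, "max_seq_length": config.max_seq_length,
    "grad_checkpoint": config.grad_checkpoint, "seed": config.seed,
    "lora_parameters": {
        # keys are module paths inside each transformer block
        "keys": [f"self_attn.{p}" if p in attention_projs else f"mlp.{p}" for p in config.lora_layers],
        "rank": config.lora_rank, "dropout": config.lora_dropout, "scale": config.lora_scale,
    },
}
config_path = Path(config.output_dir) / "lora_config.yaml"
config_path.write_text(yaml.safe_dump(lora_config))
print("Training in progress...\n" + "-" * 70)
subprocess.run([sys.executable, "-m", "mlx_lm.lora", "--config", str(config_path)], check=True)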

print("\n" + "-" * 70)
print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
print(f"\nModel adapter saved to: {config.output_dir}")
print(f"\nYou can now:")
print(f"  1. Load the fine-tuned model for inference")
print(f"  2. Evaluate on the test set")
print(f"  3. Deploy the model")
print("="*70)

## Step 8: Load the Fine-Tuned Model

In [None]:
print("\nLoading fine-tuned model with trained adapter...")

# Load model with LoRA adapter
model, tokenizer = load(
    config.model_name,
    adapter_path=config.output_dir
)

print("Model loaded successfully!")
print(f"Using adapter from: {config.output_dir}")

## Step 9: Quick Test with Sample Examples

In [None]:
# Test with sample examples
test_examples = [
    {
        "text": "Scientists say climate change is fake because they're all funded by the government.",
        "expected": "ad hominem"
    },
    {
        "text": "Vaccines are dangerous because my friend's cousin got sick after getting vaccinated.",
        "expected": "anecdote"
    },
    {
        "text": "The economy is doing great because the stock market is up.",
        "expected": "cherry picking"
    }
]

print("\n" + "="*70)
print("QUICK MODEL TEST")
print("="*70 + "\n")

for i, example in enumerate(test_examples, 1):
    print(f"Test {i}/{len(test_examples)}")
    print("-" * 70)
    print(f"Input: {example['text']}")
    print(f"Expected: {example['expected']}")
    
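    # Generate prediction. No `temp` argument: recent mlx_lm releases no longer
    # accept it in generate(), and omitting a sampler gives deterministic greedy decoding
    prompt = prepare_prompt(example['text'])
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        max_tokens=20
    )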
    
    print(f"Predicted: {response.strip()}")
    print()

print("="*70)

## Step 10: Validation Set Evaluation

In [None]:
def evaluate_model(
    model, 
    tokenizer, 
    df: pd.DataFrame,
    split_name: str = "Validation",
    num_samples: Optional[int] = None
) -> tuple:
    """
    Evaluate model on a dataset.
    
    Args:
        model: Trained model
        tokenizer: Model tokenizer
        df: DataFrame to evaluate on
        split_name: Name of the split for display
        num_samples: Number of samples to evaluate (None = all)
        
    Returns:
        Tuple of (predictions, true_labels)
    """
    if num_samples is not None:
        sample_df = df.sample(min(num_samples, len(df)), random_state=42)
        print(f"\nEvaluating on {len(sample_df)} {split_name} samples...\n")
    else:
        sample_df = df
        print(f"\nEvaluating on all {len(sample_df)} {split_name} samples...\n")
    
    predictions = []
    true_labels = []
    
    # Evaluate with progress bar
    for idx, row in tqdm(
        sample_df.iterrows(), 
        total=len(sample_df),
        desc=f"Evaluating {split_name}"
    ):
        prompt = prepare_prompt(row['text'])
        
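        # No `temp` argument: recent mlx_lm releases no longer accept it in
        # generate(), and omitting a sampler gives deterministic greedy decoding
        response = generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=20
        )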
        
        # Extract predicted label
        pred_label = response.strip().split('\n')[0].strip().lower()
        true_label = row['label'].lower()
        
        predictions.append(pred_label)
        true_labels.append(true_label)
    
    return predictions, true_labels


# Evaluate on validation set
print("\n" + "="*70)
print("VALIDATION SET EVALUATION")
print("="*70)

val_predictions, val_true_labels = evaluate_model(
    model, 
    tokenizer, 
    val_df,
    split_name="Validation",
    num_samples=None  # Evaluate on all validation samples
)

# Calculate metrics
val_accuracy = accuracy_score(val_true_labels, val_predictions)

print("\n" + "="*70)
print("VALIDATION RESULTS")
print("="*70)
print(f"\nAccuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")
print("\nDetailed Classification Report:")
print(classification_report(val_true_labels, val_predictions))
print("="*70)
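The exact-match extraction above counts any formatting drift (a trailing period, "Ad hominem fallacy", a short explanation) as an error. A sketch of a more forgiving normaliser, assuming lower-case category names like `FALLACY_CATEGORIES`; it falls back on the standard library's `difflib.get_close_matches` for near-misses, and `normalize_prediction` is a hypothetical helper:

```python
import difflib
from typing import List

def normalize_prediction(raw: str, categories: List[str], cutoff: float = 0.8) -> str:
    """Map free-form model output onto one known category, or 'unknown'."""
    first_line = raw.strip().split("\n")[0]
    cleaned = first_line.strip().strip(".:\"'").lower()
    if cleaned in categories:
        return cleaned
    # First category (in list order) contained in the output wins
    for cat in categories:
        if cat in cleaned:
            return cat
    # Near-misses such as "cherry-picking"
    match = difflib.get_close_matches(cleaned, categories, n=1, cutoff=cutoff)
    return match[0] if match else "unknown"
```

Mapping unmatched outputs to `"unknown"` keeps them visible as a single bucket in the classification report instead of creating a new pseudo-class for every stray string.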

## Step 11: Test Set Evaluation (Final Performance)

Evaluate on the held-out test set for final performance metrics.

In [None]:
# Evaluate on test set
print("\n" + "="*70)
print("TEST SET EVALUATION (Final Performance)")
print("="*70)

test_predictions, test_true_labels = evaluate_model(
    model, 
    tokenizer, 
    test_df,
    split_name="Test",
    num_samples=None  # Evaluate on all test samples
)

# Calculate metrics
test_accuracy = accuracy_score(test_true_labels, test_predictions)

print("\n" + "="*70)
print("FINAL TEST RESULTS")
print("="*70)
print(f"\nTest Accuracy: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)")
print("\nDetailed Classification Report:")
print(classification_report(test_true_labels, test_predictions))
print("\nConfusion Matrix:")
print(confusion_matrix(test_true_labels, test_predictions))
print("="*70)
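`confusion_matrix` returns a bare array whose rows and columns follow the sorted set of labels seen in either list, which is hard to read with twelve categories. Passing an explicit `labels` list and wrapping the result in a DataFrame keeps both axes named. A sketch (`labelled_confusion_matrix` is a hypothetical helper):

```python
from typing import List, Optional

import pandas as pd
from sklearn.metrics import confusion_matrix

def labelled_confusion_matrix(y_true: List[str], y_pred: List[str],
                              labels: Optional[List[str]] = None) -> pd.DataFrame:
    """Confusion matrix with 'true: ...' rows and 'pred: ...' columns."""
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    return pd.DataFrame(cm,
                        index=[f"true: {lab}" for lab in labels],
                        columns=[f"pred: {lab}" for lab in labels])
```

Calling `labelled_confusion_matrix(test_true_labels, test_predictions)` gives the same counts as the matrix printed above, with each row and column identified by its fallacy name.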

## Step 12: Save Complete Training Report

In [None]:
# Prepare comprehensive training report
training_report = {
    "model_configuration": {
        "model_name": config.model_name,
        "lora_parameters": {
            "rank": config.lora_rank,
            "alpha": config.lora_alpha,
            "dropout": config.lora_dropout,
            "scale": config.lora_scale,
            "target_layers": config.lora_layers
        }
    },
    "training_parameters": {
        "batch_size": config.batch_size,
        "gradient_accumulation_steps": config.gradient_accumulation_steps,  # planned; not passed to mlx_lm
        "effective_batch_size": config.batch_size,  # no accumulation is applied by the trainer
        "learning_rate": config.learning_rate,
        "weight_decay": config.weight_decay,
        "total_iterations": config.iters,
        "max_seq_length": config.max_seq_length,
        "seed": config.seed
    },
    "data_splits": {
        "train_samples": len(train_dataset),
        "validation_samples": len(val_dataset),
        "test_samples": len(test_dataset)
    },
    "evaluation_results": {
        "validation_accuracy": float(val_accuracy),
        "test_accuracy": float(test_accuracy),
        "num_validation_evaluated": len(val_predictions),
        "num_test_evaluated": len(test_predictions)
    },
    "fallacy_categories": FALLACY_CATEGORIES
}

# Save training report
report_path = Path(config.output_dir) / "training_report.json"
with open(report_path, "w") as f:
    json.dump(training_report, f, indent=2)

print("\nTraining report saved!")
print(f"Location: {report_path}")

# Display summary
print("\n" + "="*70)
print("TRAINING SUMMARY")
print("="*70)
print(json.dumps(training_report, indent=2))
print("="*70)
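The report stores accuracy only, which can look healthy on an imbalanced label distribution. `classification_report(..., output_dict=True)` returns the same metrics as a dictionary, so macro and weighted F1 can be added to the JSON before it is saved. A sketch (`summarize_metrics` is a hypothetical helper; `zero_division=0` silences warnings for classes that are never predicted):

```python
from typing import Dict, List

from sklearn.metrics import classification_report

def summarize_metrics(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """Accuracy plus macro- and weighted-averaged F1 as plain floats (JSON-serialisable)."""
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
    return {
        "accuracy": float(report["accuracy"]),
        "macro_f1": float(report["macro avg"]["f1-score"]),
        "weighted_f1": float(report["weighted avg"]["f1-score"]),
    }
```

For example, `training_report["evaluation_results"]["test_metrics"] = summarize_metrics(test_true_labels, test_predictions)` before the `json.dump` call would store these alongside the existing accuracy figures.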

## Training Complete!

### Summary

Your Qwen 2.5 14B model has been successfully fine-tuned for fallacy detection using Technocognitive Adaptation parameters.

### Output Files

Located in `./output/qwen14b_flicc/`:
- `adapters.safetensors` - Trained LoRA adapter weights
- `adapter_config.json` - LoRA configuration
- `training_report.json` - Complete training report with metrics

### Performance Metrics

- **Validation Accuracy**: See Step 10 results
- **Test Accuracy**: See Step 11 results (final unbiased performance)

### Usage Example

```python
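from mlx_lm import load, generate

# Load the base model with the fine-tuned LoRA adapter
model, tokenizer = load(
    "Qwen/Qwen2.5-14B-Instruct",
    adapter_path="./output/qwen14b_flicc"
)

# prepare_prompt() is the helper defined in Step 4 (it also needs
# FALLACY_CATEGORIES); copy both into your script before calling it.
prompt = prepare_prompt("Your text here")

# Generate a prediction; without a sampler, mlx_lm decodes greedily
response = generate(model, tokenizer, prompt=prompt, max_tokens=20)
print(f"Predicted fallacy: {response.strip()}")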
```

### Next Steps

1. Analyze confusion matrix to identify which fallacies are confused
2. Review misclassified examples for insights
3. Consider additional fine-tuning if needed
4. Deploy model for production use