# LoRA Fine-Tuning for Text Classification on Apple M3

This notebook implements LoRA (Low-Rank Adaptation) fine-tuning for text classification on the AG News dataset, tuned for an Apple M3 chip.
I'm having some issues with the Hugging Face `Trainer`; the training below uses a hand-written PyTorch loop instead.

## 1. Install Dependencies

In [1]:
#!pip install transformers datasets evaluate scikit-learn peft accelerate torch matplotlib seaborn


## 2. Import Libraries

In [2]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns


## 3. Set Up Device

Check if MPS (Metal Performance Shaders) is available on this Apple Silicon M3 device.

In [3]:
# Check for MPS availability on Apple Silicon
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

# Print the PyTorch version, since MPS support depends on the installed build
print(f"PyTorch version: {torch.__version__}")
if device.type == "mps":
    print("MPS device is available and will be used for training")

Using device: mps
PyTorch version: 2.5.1
MPS device is available and will be used for training


## 4. Load Dataset

We'll use the AG News dataset, which contains news articles classified into 4 categories: World, Sports, Business, and Sci/Tech.

In [4]:
# Load dataset
dataset = load_dataset("ag_news")
print(f"Dataset loaded with {len(dataset['train'])} training examples and {len(dataset['test'])} test examples")

# Examine a sample
print("\nSample training example:")
print(dataset['train'][0])

# Check class distribution
label_counts = {}
for example in dataset['train']:
    label = example['label']
    label_counts[label] = label_counts.get(label, 0) + 1

print("\nClass distribution in training set:")
for label, count in label_counts.items():
    print(f"Class {label}: {count} examples ({count/len(dataset['train'])*100:.2f}%)")

# Create class label mapping
label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

Dataset loaded with 120000 training examples and 7600 test examples

Sample training example:
{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}

Class distribution in training set:
Class 2: 30000 examples (25.00%)
Class 3: 30000 examples (25.00%)
Class 1: 30000 examples (25.00%)
Class 0: 30000 examples (25.00%)


## 5. Set Up Tokenizer

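The default below is `bert-base-uncased`. `distilbert-base-uncased` is a smaller, faster alternative, and the M3 should be able to handle larger models as well.

In [5]:
# model_name = "distilbert-base-uncased"  # Smaller, faster model
model_name = "bert-base-uncased"          # Default used in this run
# model_name = "roberta-base"             # Alternative base model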

print(f"Using model: {model_name}")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define tokenization function
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

Using model: bert-base-uncased
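
`padding="max_length"` together with `truncation=True` means every example comes out exactly 128 tokens long. Here is a rough sketch of that step on already-tokenized IDs, assuming a pad ID of 0 (BERT's `[PAD]` ID). The helper `pad_or_truncate` and the token IDs are hypothetical stand-ins, not tokenizer output, and the real tokenizer additionally keeps the trailing `[SEP]` token when it truncates, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_length=128, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = list(token_ids[:max_length])          # truncate long inputs
    attention_mask = [1] * len(ids)             # 1 = real token
    n_pad = max_length - len(ids)
    # pad short inputs; padded positions get mask 0 so attention ignores them
    return ids + [pad_id] * n_pad, attention_mask + [0] * n_pad


# Hypothetical token IDs for illustration only
short_ids, short_mask = pad_or_truncate([101, 2813, 2358, 102])
long_ids, long_mask = pad_or_truncate(list(range(1, 201)))

print(len(short_ids), sum(short_mask))  # 128 4
print(len(long_ids), sum(long_mask))    # 128 128
```

Padding everything to a fixed 128 tokens is simple but does extra work on short texts. Padding each batch dynamically (for example with `DataCollatorWithPadding`) is a common alternative.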


## 6. Tokenize Datasets

In [6]:
# Tokenize datasets
print("Tokenizing training dataset...")
tokenized_train = dataset["train"].map(tokenize_function, batched=True)
print("Tokenizing testing dataset...")
tokenized_test = dataset["test"].map(tokenize_function, batched=True)

# Format datasets for PyTorch
tokenized_train.set_format("torch", columns=["input_ids", "attention_mask", "label"])
tokenized_test.set_format("torch", columns=["input_ids", "attention_mask", "label"])
print("Tokenization complete!")

Tokenizing training dataset...
Map: 100%|██████████| 120000/120000 [00:04<00:00, 24781.50 examples/s]
Tokenizing testing dataset...
Map: 100%|██████████| 7600/7600 [00:00<00:00, 19069.80 examples/s]
Tokenization complete!


## 7. Create DataLoaders with Batch Size Optimized for M3

In [7]:
# Create DataLoaders; 16 is a starting point, tune it to the available unified memory
batch_size = 16
train_dataloader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(tokenized_test, batch_size=batch_size)

print(f"Created DataLoaders with batch size {batch_size}")
print(f"Training batches: {len(train_dataloader)}")
print(f"Evaluation batches: {len(eval_dataloader)}")

Created DataLoaders with batch size 16
Training batches: 7500
Evaluation batches: 475
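
The batch counts above follow from ceiling division of the dataset sizes by the batch size. `DataLoader` keeps a final partial batch by default, though here both sizes divide evenly. A quick check of the arithmetic:

```python
import math

batch_size = 16
train_batches = math.ceil(120_000 / batch_size)  # training examples / batch size
eval_batches = math.ceil(7_600 / batch_size)     # test examples / batch size

print(train_batches, eval_batches)  # 7500 475
print(4 * train_batches)            # optimizer steps over 4 epochs: 30000
```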


## 8. Load Model and Apply LoRA

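LoRA freezes the pretrained weights and adds small trainable low-rank "adapter" matrices to selected layers, so fine-tuning updates only a small fraction of the parameters (about 1.7% here). The M3 has enough headroom for a relatively high rank, so we use `r=16` instead of the PEFT default of 8.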
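
To make the idea concrete, here is a minimal, dependency-free sketch of a LoRA-adapted linear layer. It is not how PEFT implements it internally. The helper names `matvec` and `lora_forward` and the toy dimensions are made up for illustration. The frozen weight `W0` is left untouched, and only the two small matrices `A` and `B` would be trained. With `B` initialized to zero (PEFT's default for linear layers), the adapted layer starts out identical to the base layer.

```python
import random


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W0, A, B, x, alpha, r):
    """Compute W0 @ x + (alpha / r) * B @ (A @ x)."""
    base = matvec(W0, x)               # frozen pretrained path
    update = matvec(B, matvec(A, x))   # low-rank trainable path
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


random.seed(0)
d_in, d_out, r, alpha = 6, 4, 2, 4     # toy sizes; the notebook uses r=16, alpha=32
W0 = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(r)]  # r x d_in
B = [[0.0] * r for _ in range(d_out)]                              # d_out x r, zeros
x = [random.gauss(0, 1) for _ in range(d_in)]

y_base = matvec(W0, x)
y_lora = lora_forward(W0, A, B, x, alpha, r)
print(y_lora == y_base)  # True: the adapter is a no-op until B is trained

full_params = d_in * d_out
lora_params = r * (d_in + d_out)
print(full_params, lora_params)  # 24 20
```

At toy sizes the saving is small. For a 768 x 768 projection with `r=16`, the same formula gives 24,576 adapter parameters against 589,824 in the full weight matrix.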

In [8]:
print(f"Loading {model_name}...")
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=4)

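# Pick LoRA target modules by architecture. RoBERTa uses the same module names
# as BERT, and "roberta" contains the substring "bert", so one branch covers both.
if "distilbert" in model_name:
    target_modules = ["q_lin", "k_lin", "v_lin", "out_lin"]
elif "bert" in model_name:  # matches bert-* and roberta-*
    # "output.dense" matches both the attention output and the feed-forward output projections
    target_modules = ["query", "key", "value", "output.dense"]
else:
    target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]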

config = LoraConfig(
    r=16,                           # Higher rank for M3 (more expressive adapters)
    lora_alpha=32,                  # Scaling factor
    target_modules=target_modules,  # Model-specific layers to adapt
    lora_dropout=0.1,               # Dropout probability for LoRA layers
    bias="none",                    # Don't train bias parameters
    task_type=TaskType.SEQ_CLS      # Sequence classification task
)

# Apply LoRA to model
model = get_peft_model(model, config)
model.print_trainable_parameters()  # prints the summary itself and returns None

# Move model to device
model = model.to(device)

Loading bert-base-uncased...

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  warn("The installed version of bitsandbytes was compiled without GPU support. "
'NoneType' object has no attribute 'cadam32bit_grad_fp32'
trainable params: 1,920,004 || all params: 111,405,320 || trainable%: 1.7234
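
The reported trainable count can be reproduced by hand. This is a sketch of the arithmetic, assuming the standard BERT-base shape (12 layers, hidden size 768, feed-forward size 3072) and that the `"output.dense"` pattern matches both the attention output and the feed-forward output projection in each layer. With `TaskType.SEQ_CLS` the classification head is trained as well, so it is counted too.

```python
hidden, intermediate, num_layers, r, num_labels = 768, 3072, 12, 16, 4


def lora_params(d_in, d_out, rank):
    # A is rank x d_in and B is d_out x rank
    return rank * (d_in + d_out)


per_layer = (
    3 * lora_params(hidden, hidden, r)        # query, key, value
    + lora_params(hidden, hidden, r)          # attention.output.dense
    + lora_params(intermediate, hidden, r)    # output.dense (feed-forward)
)
lora_total = num_layers * per_layer
classifier = hidden * num_labels + num_labels  # weight + bias of the new head
trainable = lora_total + classifier

print(f"{trainable:,}")                          # 1,920,004
print(f"{100 * trainable / 111_405_320:.4f}%")   # 1.7234%
```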


## 9. Set Up Optimizer and Learning Rate Scheduler

In [9]:
# Set up optimizer over the trainable (LoRA and classifier) parameters only.
# OneCycleLR below overrides this lr, so max_lr is the value that matters.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=2e-3
)

# Set up learning rate scheduler
num_epochs = 4
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=2e-3, 
    total_steps=num_training_steps,
    pct_start=0.1  # Warm up for first 10% of training
)

# Loss function
loss_fn = torch.nn.CrossEntropyLoss()
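
It helps to see what learning rate this schedule actually produces. The sketch below re-implements the two-phase cosine schedule in plain Python, assuming PyTorch's documented `OneCycleLR` defaults (`div_factor=25`, `final_div_factor=1e4`, `anneal_strategy="cos"`, `three_phase=False`). `one_cycle_lr` is a hypothetical helper written from my reading of that behavior, not a PyTorch function, so treat the exact values as approximate.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at a given step of a one-cycle schedule."""
    initial_lr = max_lr / div_factor          # starting LR, not the optimizer's lr
    min_lr = initial_lr / final_div_factor    # LR at the very end
    warmup_end = pct_start * total_steps - 1  # last step of the warm-up phase
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from start (pct=0) to end (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    pct = (step - warmup_end) / (last_step - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)


total_steps = 4 * 7500  # num_epochs * len(train_dataloader)
for step in (0, 1500, 2999, 5427, 15000, 29999):
    print(step, f"{one_cycle_lr(step, total_steps, 2e-3):.2e}")
```

Under these assumptions the run starts at `max_lr / 25 = 8e-5` rather than the `2e-3` passed to `AdamW`, reaches the `2e-3` peak around step 3,000, and then decays toward zero.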

## 10. Training Function with Progress Tracking

In [11]:
def train_epoch(model, dataloader, optimizer, lr_scheduler, epoch, num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    
    for batch in progress_bar:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        logits = outputs.logits
        
        # Compute loss
        loss = loss_fn(logits, batch["label"])
        epoch_loss += loss.item()
        
        # Backward pass
        loss.backward()
        
        # Update parameters
        optimizer.step()
        lr_scheduler.step()
        
        # Update progress bar
        progress_bar.set_postfix({"loss": loss.item()})
    
    return epoch_loss / len(dataloader)

## 11. Evaluation Function with Detailed Metrics

In [12]:
def evaluate(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []
    all_losses = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
            predictions = torch.argmax(outputs.logits, dim=-1)
            
            # Compute loss
            loss = loss_fn(outputs.logits, batch["label"])
            all_losses.append(loss.item())
            
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(batch["label"].cpu().numpy())
    
    # Calculate metrics
    metrics = {
        "accuracy": accuracy_score(all_labels, all_preds),
        "eval_loss": np.mean(all_losses)
    }
    
    return all_preds, all_labels, metrics

## 12. Training Loop with Visualization

In [None]:
print("Starting training...")
best_accuracy = 0
best_model_state = None

# Track metrics for plotting
train_losses = []
eval_losses = []
accuracies = []

for epoch in range(num_epochs):
    # Training phase
    avg_loss = train_epoch(model, train_dataloader, optimizer, lr_scheduler, epoch, num_epochs)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_loss:.4f}")
    
    # Evaluation phase
    predictions, labels, metrics = evaluate(model, eval_dataloader)
    accuracy = metrics["accuracy"]
    eval_loss = metrics["eval_loss"]
    
    eval_losses.append(eval_loss)
    accuracies.append(accuracy)
    
    print(f"Epoch {epoch+1}/{num_epochs} - Eval Loss: {eval_loss:.4f} - Accuracy: {accuracy:.4f}")
    
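    # Save best model. Clone the tensors: state_dict().copy() is a shallow copy
    # whose tensors would keep changing as training continues.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        print(f"New best model with accuracy: {best_accuracy:.4f}")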

print("Training complete!")
print(f"Best accuracy: {best_accuracy:.4f}")

# Load best model if we saved one
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print("Loaded best model state")

# Plot training metrics
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(eval_losses, label='Evaluation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curves')

plt.subplot(1, 2, 2)
plt.plot(accuracies, label='Accuracy', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Accuracy Progression')
plt.grid(True)

plt.tight_layout()
plt.show()

Starting training...


Epoch 1/4:  72%|███████▏  | 5427/7500 [1:07:01<21:36,  1.60it/s, loss=1.39]   
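
Two quick sanity checks on the progress line above, as plain arithmetic. A loss of 1.39 is essentially ln 4, the cross-entropy of a uniform guess over four classes, so at that point the model may not have been learning. The progress bar shows only the most recent batch, so this is a hint rather than proof, and the `2e-3` peak learning rate is one suspect worth testing rather than a confirmed cause. Separately, at roughly 1.60 it/s a 7,500-step epoch takes a bit over an hour.

```python
import math

# Cross-entropy of a uniform prediction over 4 classes
num_classes = 4
chance_loss = math.log(num_classes)
print(f"chance-level loss: {chance_loss:.3f}")  # chance-level loss: 1.386

# Rough training time from the logged throughput (evaluation time not included)
steps_per_epoch, its_per_sec, num_epochs = 7500, 1.60, 4
epoch_minutes = steps_per_epoch / its_per_sec / 60
total_hours = epoch_minutes * num_epochs / 60
print(f"~{epoch_minutes:.0f} min per epoch, ~{total_hours:.1f} h total")
# ~78 min per epoch, ~5.2 h total
```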

## 13. Comprehensive Evaluation

In [None]:
# Final evaluation with the best model
predictions, labels, _ = evaluate(model, eval_dataloader)

# Calculate additional metrics
from sklearn.metrics import precision_recall_fscore_support

# Get per-class metrics
precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average=None)

# Create a summary table
import pandas as pd
metrics_df = pd.DataFrame({
    'Class': [label_mapping[i] for i in range(4)],
    'Precision': precision,
    'Recall': recall,
    'F1 Score': f1
})
print("Per-class Performance:")
display(metrics_df)

# Print classification report
print("\nClassification Report:")
target_names = [label_mapping[i] for i in range(4)]
print(classification_report(labels, predictions, target_names=target_names))

# Create confusion matrix
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
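
For reference, per-class precision and recall can be read straight off a confusion matrix laid out as in scikit-learn (rows are true classes, columns are predicted classes). The matrix below is a made-up toy example, not output from this notebook.

```python
label_names = ["World", "Sports", "Business", "Sci/Tech"]

# Toy confusion matrix for illustration only: cm[true][predicted]
cm = [
    [8, 1, 1, 0],
    [0, 10, 0, 0],
    [1, 0, 7, 2],
    [1, 0, 2, 7],
]

n = len(cm)
row_sums = [sum(cm[i]) for i in range(n)]                       # true counts per class
col_sums = [sum(cm[i][j] for i in range(n)) for j in range(n)]  # predicted counts per class

precision = [cm[k][k] / col_sums[k] for k in range(n)]  # correct / predicted as k
recall = [cm[k][k] / row_sums[k] for k in range(n)]     # correct / actually k
accuracy = sum(cm[k][k] for k in range(n)) / sum(row_sums)

for name, p, r in zip(label_names, precision, recall):
    print(f"{name:9s} precision={p:.2f} recall={r:.2f}")
print(f"accuracy={accuracy:.2f}")  # accuracy=0.80
```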

## 14. Save the Model

In [None]:
# Save the model
peft_model_id = f"m3-peft-{model_name.split('/')[-1]}-agnews"
model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)
print(f"Model saved to {peft_model_id}")

## 15. Interactive Testing

In [None]:
# Example inference function
def predict_text(text):
    model.eval()  # make sure dropout is disabled at inference time
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
    
    # Get probability distribution
    probs = probabilities[0].cpu().numpy()
    
    return predicted_class, label_mapping[predicted_class], probs

# Test with some examples
test_texts = [
    "NASA successfully launches new Mars rover with advanced sampling technology",
    "Liverpool FC wins dramatic match against Manchester United with last-minute goal",
    "Stock markets plunge amid concerns about inflation and interest rates",
    "New research shows promising results for quantum computing breakthroughs"
]

for text in test_texts:
    _, prediction, probs = predict_text(text)
    print(f"Text: {text}\nPredicted category: {prediction}")
    
    # Show probability distribution
    for i, p in enumerate(probs):
        print(f"  {label_mapping[i]}: {p:.4f} ({p*100:.1f}%)")
    print()
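
The softmax step in `predict_text` turns the four logits into probabilities that sum to 1, and the predicted class is the index of the largest one. This is a standalone sketch with made-up logits (not model output).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)                          # subtract the max to avoid overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
toy_logits = [0.2, 3.1, -0.5, 0.4]           # hypothetical values for illustration
probs = softmax(toy_logits)
predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax

print(label_mapping[predicted])  # Sports
for i, p in enumerate(probs):
    print(f"  {label_mapping[i]}: {p:.4f}")
```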

## 16. Try Your Own Text

In [None]:
# Input your own text for classification
your_text = input("Enter news text to classify: ")

_, prediction, probs = predict_text(your_text)
print(f"\nPredicted category: {prediction}")

# Visualize the probabilities (class names taken from label_mapping so this cell runs on its own)
plt.figure(figsize=(10, 6))
plt.bar(list(label_mapping.values()), probs)
plt.title('Category Probabilities')
plt.xlabel('Category')
plt.ylabel('Probability')
plt.ylim(0, 1)
for i, v in enumerate(probs):
    plt.text(i, v + 0.01, f'{v:.2f}', ha='center')
plt.show()

## 17. Loading the Model Later (Code Reference)

In [None]:
# This cell shows how to load and use the saved model later
'''
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Path to your saved model
peft_model_id = "m3-peft-bert-base-uncased-agnews"

# Load the configuration
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model
model = AutoModelForSequenceClassification.from_pretrained(
    config.base_model_name_or_path, 
    num_labels=4
)

# Load the PEFT adapter weights
model = PeftModel.from_pretrained(model, peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)

# Move to MPS device for M3
device = "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)
model.eval()  # inference mode: disables dropout

# Use for inference
def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
    
    label_mapping = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    return label_mapping[predicted_class]
'''