# DistilBERT + LoRA

Fine-tune a DistilBERT classifier with LoRA (Low-Rank Adaptation) for toxic comment detection.
This notebook follows the same pipeline as the full fine-tuning baseline, but only a small fraction of the parameters is trained.

## 1. Setup

In [None]:
import pandas as pd
import numpy as np
import pickle
import json
import sys
import torch
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# One fixed seed shared by NumPy, PyTorch and the train/test split below.
RANDOM_SEED = 42

# Dataset root (relative to the notebook) and the sub-folder that receives
# the LoRA model artifacts.
DATA_DIR = Path("../../data")
MODEL_DIR = DATA_DIR / "models" / "transformer_lora"

# Put the directory two levels up at the front of sys.path -- presumably so
# that the later `from model import ...` resolves to the local package.
sys.path.insert(0, str(Path("../..").resolve()))

# Create the artifact folder up front; a pre-existing folder is fine.
MODEL_DIR.mkdir(parents=True, exist_ok=True)

torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

## 2. Load Data

In [None]:
# Read the training CSV; the `y` column is averaged below and reported as
# the toxicity rate.
train_csv = DATA_DIR / "train_dataset.csv"
df = pd.read_csv(train_csv)

toxicity_rate = df['y'].mean()
print(f"Total samples: {len(df)}")
print(f"Toxicity rate: {toxicity_rate:.2%}")

## 3. Train/Test Split

In [None]:
# Raw comment texts and their labels.
X = df['text'].values
y = df['y'].values

# Hold out 20% for testing; stratifying on y keeps the class balance the same
# in both parts, and the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=RANDOM_SEED,
)

n_train, n_test = len(X_train), len(X_test)
print(f"Train: {n_train}, Test: {n_test}")

## 4. Fine-tune with LoRA

LoRA (Low-Rank Adaptation) freezes the pretrained weights and injects trainable
low-rank decomposition matrices into selected layers (here the `q_lin` and `v_lin`
attention projections), drastically reducing the number of trainable parameters
while aiming to preserve quality.

In [None]:
from model import LoRATransformerClassifier
from tqdm.auto import tqdm

# LoRA-specific settings handed to the project classifier. `target_modules`
# names the layers that receive adapters (q_lin / v_lin are the query and
# value projections in Hugging Face's DistilBERT attention blocks).
lora_config = {
    "lora_r": 8,
    "lora_alpha": 16,
    "lora_dropout": 0.1,
    "target_modules": ["q_lin", "v_lin"],
}

clf = LoRATransformerClassifier(
    model_name="distilbert-base-uncased",
    max_length=128,
    batch_size=64,
    **lora_config,
)

# Texts are passed as a plain Python list; labels stay as the array from the split.
train_texts = X_train.tolist()
clf.fit(train_texts, y_train, epochs=3, batch_size=64, lr=2e-5)

## 5. Evaluate

In [None]:
# Predict labels for the held-out split and print per-class metrics.
class_names = ['Safe', 'Toxic']
y_pred = clf.predict(X_test.tolist())
print(classification_report(y_test, y_pred, target_names=class_names))

# scikit-learn's confusion matrix has true labels on the rows and predicted
# labels on the columns. Name the ticks and axes so the heatmap is readable
# on its own and matches the labelled plot in the visualization section.
cm = confusion_matrix(y_test, y_pred)
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                 xticklabels=class_names, yticklabels=class_names)
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.title('Confusion Matrix (LoRA)')
plt.show()

## 6. Save Model

In [None]:
# Let the classifier write its own artifacts into the model directory.
clf.save(str(MODEL_DIR))

# Store the exact held-out split next to the model so it can be reused later.
held_out = {'X_test': X_test, 'y_test': y_test}
with (MODEL_DIR / "test_data.pkl").open('wb') as f:
    pickle.dump(held_out, f)

print("Models saved!")

## 7. Load Model Using Our Classes

In [None]:
import importlib

# Re-import the local `model` package so code edits are picked up without a
# kernel restart. Each module is reloaded only if it is actually loaded:
# indexing sys.modules['model.models'] unconditionally raises KeyError when
# that submodule was never imported. The submodule goes first so the package
# reload that follows sees the refreshed submodule.
for _module_name in ("model.models", "model"):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

from model import LoRATransformerClassifier

# Build a classifier from the artifacts saved in MODEL_DIR.
loaded_model = LoRATransformerClassifier(model_dir=str(MODEL_DIR))

# Smoke test on two hand-written examples.
test_texts = ["You are stupid", "Have a nice day"]
predictions = loaded_model.predict(test_texts)
probas = loaded_model.predict_proba(test_texts)

print("Test predictions:")
for text, pred, proba in zip(test_texts, predictions, probas):
    # proba[1] is the second entry of each probability row -- presumably the
    # score of the positive ("Toxic") class; confirm against model.py.
    print(f"  '{text}' -> {pred} (prob: {proba[1]:.3f})")

## 8. Comprehensive Evaluation with get_metrics()

In [None]:
# Evaluate the reloaded model on at most the first `eval_size` held-out samples.
eval_size = 1000
X_test_subset = X_test[:eval_size].tolist()
y_test_subset = y_test[:eval_size]

results = loaded_model.get_metrics(X_test_subset, y_test_subset, n_latency_runs=50)

banner = "=" * 60
print(banner)
print("EVALUATION RESULTS (LoRA)")
print(banner)

print("\nQuality Metrics:")
for name, score in results['quality'].items():
    print(f"  {name:15s}: {score:.4f}")

# Printed in the [[TN, FP], [FN, TP]] layout used by the labels below.
print("\nConfusion Matrix:")
cm_array = np.array(results['confusion_matrix'])
tn, fp = cm_array[0, 0], cm_array[0, 1]
print(f"  [[TN={tn:5d}, FP={fp:5d}]")
fn, tp = cm_array[1, 0], cm_array[1, 1]
print(f"   [FN={fn:5d}, TP={tp:5d}]]")

print("\nLatency Metrics:")
for name, measurement in results['latency'].items():
    print(f"  {name:20s}: {measurement:.4f}")

throughput = results['throughput_samples_per_sec']
peak_memory = results['peak_memory_mb']
print(f"\nThroughput: {throughput:.2f} samples/sec")
print(f"Peak Memory: {peak_memory:.2f} MB")

## 9. Visualize Performance Metrics

In [None]:
# Three side-by-side panels: quality scores, confusion matrix, latency stats.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_quality, ax_cm, ax_latency = axes

# --- Panel 1: quality metrics as a bar chart, each bar annotated with its value.
quality_data = results['quality']
metric_names = list(quality_data.keys())
metric_scores = list(quality_data.values())
ax_quality.bar(metric_names, metric_scores, color='steelblue')
ax_quality.set_title('Quality Metrics (LoRA)', fontweight='bold')
ax_quality.set_ylim(0, 1)
ax_quality.tick_params(axis='x', rotation=45)
for pos, score in enumerate(metric_scores):
    ax_quality.text(pos, score + 0.02, f'{score:.3f}', ha='center', fontsize=9)

# --- Panel 2: confusion matrix heatmap with class names on both axes.
class_names = ['Safe', 'Toxic']
cm_array = np.array(results['confusion_matrix'])
sns.heatmap(cm_array, annot=True, fmt='d', cmap='Blues', ax=ax_cm,
            xticklabels=class_names, yticklabels=class_names)
ax_cm.set_title('Confusion Matrix (LoRA)', fontweight='bold')
ax_cm.set_ylabel('True Label')
ax_cm.set_xlabel('Predicted Label')

# --- Panel 3: latency summary statistics, in the order Mean / Std / Min / Max.
latency_data = results['latency']
latency_keys = {
    'Mean': 'latency_mean_ms',
    'Std': 'latency_std_ms',
    'Min': 'latency_min_ms',
    'Max': 'latency_max_ms',
}
latency_labels = list(latency_keys)
latency_values = [latency_data[key] for key in latency_keys.values()]
ax_latency.bar(latency_labels, latency_values, color='coral')
ax_latency.set_title('Latency (ms per sample)', fontweight='bold')
ax_latency.set_ylabel('Milliseconds')
for pos, ms in enumerate(latency_values):
    ax_latency.text(pos, ms + 0.01, f'{ms:.2f}', ha='center', fontsize=9)

plt.tight_layout()
plt.show()

# Short text recap of the headline numbers.
throughput = results['throughput_samples_per_sec']
mean_latency = latency_data['latency_mean_ms']
f1 = quality_data['f1_score']
print("\nModel Summary (LoRA):")
print(f"  Throughput: {throughput:.0f} samples/sec")
print(f"  Avg Latency: {mean_latency:.2f} ms/sample")
print(f"  F1 Score: {f1:.4f}")

## 10. Save Evaluation Results

In [None]:
# Write the metrics dictionary next to the model artifacts as indented JSON.
results_path = MODEL_DIR / "evaluation_results.json"
with results_path.open('w') as f:
    json.dump(results, f, indent=2)

print(f"Evaluation results saved to {results_path}")