# DistilBERT Sequence Classification Trainer

This notebook demonstrates how to fine-tune a DistilBERT model for sequence classification using the Hugging Face Transformers library.


In [None]:
# Install accelerate package - required for Trainer
%pip install "accelerate>=0.26.0" --quiet


## Prepare Dependencies

In [None]:
# Import required libraries
import json 
import torch 
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import os
import numpy as np 
import evaluate 
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments

# Disable wandb logging
os.environ["WANDB_DISABLED"] = "true"


## Prepare model and data

In [None]:
#Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased", 
    num_labels=2
)

# tokenizer = AutoTokenizer.from_pretrained("roberta-base")
# model = AutoModelForSequenceClassification.from_pretrained(
#     "roberta-base", 
#     num_labels=2
# )
print(model)


In [None]:
import random 

random.seed(2025)

# Load dataset
with open("../data/classification/incorrect/train_v1.json", "r", encoding="utf-8") as f:
    data = json.load(f)

with open("../data/classification/incorrect/train_v2.json", "r", encoding="utf-8") as f:
    data_v2 = json.load(f)

# Prepare dataset
dataset = {
    "label": [],
    "text": []
}
for dialog in data: 
    for i, sentence in enumerate(dialog): 
        if i == 0: 
            dataset["label"].append(1)
            dataset["text"].append(sentence)
        else: 
            if sentence == None : 
                continue 
            else: 
                dataset["label"].append(0)
                dataset["text"].append(sentence)

for dialog in data_v2: 
    for i, sentence in enumerate(dialog): 
        if i == 0: 
            dataset["label"].append(1)
            dataset["text"].append(sentence)
        else: 
            if sentence == None : 
                continue 
            else: 
                dataset["label"].append(0)
                dataset["text"].append(sentence)

# Convert to Hugging Face Dataset
dataset = Dataset.from_dict(dataset)

# Split dataset into train and test
train_size = int(0.9 * len(dataset))
train_dataset = dataset.select(range(train_size))
test_dataset = dataset.select(range(train_size, len(dataset)))

# Print dataset sizes
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

# Print ratio classification
print(f"Ratio classification: {train_dataset['label'].count(1) / len(train_dataset)}")





## Training phase

In [None]:
# Tokenize dataset
def tokenize_function(examples): 
    tokens = tokenizer(examples["text"], padding="max_length", truncation=True)
    tokens["labels"] = examples["label"]
    return tokens

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)


In [6]:
# Define metrics and compute_metrics function
metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred): 
    logits, labels = eval_pred 
    predictions = np.argmax(logits, axis=-1)
    accuracy = metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    # Return a dictionary with all metrics
    return {
        "accuracy": accuracy["accuracy"] if accuracy is not None else None,
        "f1": f1["f1"] if f1 is not None else None
    }


In [7]:
# Set up training arguments
training_args = TrainingArguments(
    output_dir="../results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay = 0.1,

)


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [8]:
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)


In [9]:
# Train the model
trainer.train()


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.5025,0.421037,0.780756,0.595691
2,0.41,0.40865,0.778007,0.619552
3,0.3755,0.408358,0.789003,0.640094
4,0.3179,0.444314,0.779381,0.637288
5,0.2832,0.467105,0.779381,0.631458


TrainOutput(global_step=4090, training_loss=0.37012760423506386, metrics={'train_runtime': 364.7584, 'train_samples_per_second': 179.393, 'train_steps_per_second': 11.213, 'total_flos': 8668004231055360.0, 'train_loss': 0.37012760423506386, 'epoch': 5.0})

In [10]:
# Evaluate the model
evaluation_results = trainer.evaluate()
print(evaluation_results)


{'eval_loss': 0.4671052396297455, 'eval_accuracy': 0.7793814432989691, 'eval_f1': 0.6314580941446614, 'eval_runtime': 2.4071, 'eval_samples_per_second': 604.451, 'eval_steps_per_second': 37.804, 'epoch': 5.0}


## Testing phase

In [13]:
import json
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
import torch
import torch.nn.functional as F

# Load test data
test_file_path = "../data/classification/incorrect/train_v2.json"
with open(test_file_path, "r", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Loaded test data with {len(test_data)} dialogs")

# Process test data - create labels based on position (first sentence = correct=1, others = incorrect=0)
test_sentences = []
test_labels = []

for dialog in test_data:
    for i, sentence in enumerate(dialog):
        if sentence is None:
            continue
        
        # Handle different data structures
        if isinstance(sentence, list):
            # If it's a list of tokens, join them
            if all(isinstance(token, list) for token in sentence):
                # Token-level structure [[word, tag, ...], ...]
                text = " ".join([token[0] for token in sentence if len(token) > 0])
            else:
                # Simple list of words
                text = " ".join(sentence)
        elif isinstance(sentence, str):
            text = sentence
        elif isinstance(sentence, dict) and "text" in sentence:
            text = sentence["text"]
        else:
            continue
            
        if text.strip():  # Only add non-empty sentences
            test_sentences.append(text.strip())
            # First sentence in dialog = correct (label 1), others = incorrect (label 0)
            test_labels.append(1 if i == 0 else 0)

print(f"Processed {len(test_sentences)} test sentences")
print(f"Label distribution - Correct (1): {test_labels.count(1)}, Incorrect (0): {test_labels.count(0)}")

# Create test dataset
test_dataset = Dataset.from_dict({
    "text": test_sentences,
    "label": test_labels
})

# Tokenize test dataset
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
print(f"Test dataset: {tokenized_test_dataset}")

# Run predictions
print("\n" + "="*50)
print("RUNNING MODEL PREDICTIONS...")
print("="*50)

predictions = trainer.predict(tokenized_test_dataset)
y_pred_logits = predictions.predictions
y_true = test_labels

# Convert logits to probabilities and predicted classes
y_pred_probs = F.softmax(torch.tensor(y_pred_logits), dim=-1).numpy()
y_pred_classes = np.argmax(y_pred_logits, axis=-1)

print(f"Predictions shape: {y_pred_logits.shape}")
print(f"Probabilities shape: {y_pred_probs.shape}")
print(f"Predicted classes shape: {y_pred_classes.shape}")


Loaded test data with 1465 dialogs
Processed 5782 test sentences
Label distribution - Correct (1): 1465, Incorrect (0): 4317


Map: 100%|██████████| 5782/5782 [00:00<00:00, 16295.60 examples/s]

Test dataset: Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 5782
})

RUNNING MODEL PREDICTIONS...





Predictions shape: (5782, 2)
Probabilities shape: (5782, 2)
Predicted classes shape: (5782,)


In [17]:
# ========================================
# COMPREHENSIVE MODEL EVALUATION
# ========================================

# Calculate comprehensive metrics
accuracy = accuracy_score(y_true, y_pred_classes)
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred_classes, average='weighted')

print("="*60)
print("OVERALL MODEL PERFORMANCE")
print("="*60)
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")

print("\n" + "="*60)
print("DETAILED CLASSIFICATION REPORT")
print("="*60)
print(classification_report(y_true, y_pred_classes, target_names=['Incorrect (0)', 'Correct (1)']))

print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)
cm = confusion_matrix(y_true, y_pred_classes)
print("           Predicted")
print("         0 (Inc)  1 (Cor)")
print(f"Actual 0  {cm[0,0]:6d}  {cm[0,1]:6d}")
print(f"Actual 1  {cm[1,0]:6d}  {cm[1,1]:6d}")

# Calculate per-class metrics
print(f"\nTrue Negatives (correctly identified incorrect): {cm[0,0]}")
print(f"False Positives (incorrectly identified as correct): {cm[0,1]}")
print(f"False Negatives (incorrectly identified as incorrect): {cm[1,0]}")
print(f"True Positives (correctly identified correct): {cm[1,1]}")

print("\n" + "="*60)
print("MODEL PREDICTIONS ANALYSIS")
print("="*60)

# Get prediction confidence distribution
max_probs = np.max(y_pred_probs, axis=1)
print(f"Average prediction confidence: {np.mean(max_probs):.4f}")
print(f"Min prediction confidence: {np.min(max_probs):.4f}")
print(f"Max prediction confidence: {np.max(max_probs):.4f}")

# Show confidence distribution by class
correct_predictions = y_pred_classes == y_true
print(f"\nCorrect predictions: {np.sum(correct_predictions)} / {len(y_true)} ({np.mean(correct_predictions):.3f})")
print(f"Average confidence for correct predictions: {np.mean(max_probs[correct_predictions]):.4f}")
print(f"Average confidence for incorrect predictions: {np.mean(max_probs[~correct_predictions]):.4f}")


OVERALL MODEL PERFORMANCE
Accuracy: 0.8622
Precision: 0.8871
Recall: 0.8622
F1-Score: 0.8679

DETAILED CLASSIFICATION REPORT
               precision    recall  f1-score   support

Incorrect (0)       0.96      0.85      0.90      4317
  Correct (1)       0.67      0.90      0.77      1465

     accuracy                           0.86      5782
    macro avg       0.82      0.87      0.83      5782
 weighted avg       0.89      0.86      0.87      5782


CONFUSION MATRIX
           Predicted
         0 (Inc)  1 (Cor)
Actual 0    3670     647
Actual 1     150    1315

True Negatives (correctly identified incorrect): 3670
False Positives (incorrectly identified as correct): 647
False Negatives (incorrectly identified as incorrect): 150
True Positives (correctly identified correct): 1315

MODEL PREDICTIONS ANALYSIS
Average prediction confidence: 0.8886
Min prediction confidence: 0.5003
Max prediction confidence: 0.9991

Correct predictions: 4985 / 5782 (0.862)
Average confidence for corre

In [18]:
# ========================================
# EXAMPLE PREDICTIONS WITH ACTUAL TEXT
# ========================================

print("="*80)
print("EXAMPLE PREDICTIONS (Text + Predictions)")
print("="*80)

# Show examples of correct and incorrect predictions
def show_prediction_examples(texts, true_labels, pred_classes, pred_probs, title, indices=None, n_examples=10):
    print(f"\n{title}")
    print("-" * len(title))
    
    if indices is None:
        indices = range(min(n_examples, len(texts)))
    
    for i in indices[:n_examples]:
        text = texts[i]
        true_label = true_labels[i]
        pred_label = pred_classes[i]
        confidence = pred_probs[i][pred_label]
        
        status = "✓ CORRECT" if true_label == pred_label else "✗ WRONG"
        true_meaning = "Correct English" if true_label == 1 else "Incorrect English"
        pred_meaning = "Correct English" if pred_label == 1 else "Incorrect English"
        
        print(f"\nExample {i+1}: {status}")
        print(f"Text: \"{text[:100]}{'...' if len(text) > 100 else ''}\"")
        print(f"True Label: {true_label} ({true_meaning})")
        print(f"Predicted: {pred_label} ({pred_meaning}) [Confidence: {confidence:.3f}]")

# Show correct predictions
correct_indices = np.where(y_pred_classes == y_true)[0]
show_prediction_examples(test_sentences, y_true, y_pred_classes, y_pred_probs, 
                        "CORRECTLY PREDICTED EXAMPLES", correct_indices, 5)

# Show incorrect predictions  
incorrect_indices = np.where(y_pred_classes != y_true)[0]
show_prediction_examples(test_sentences, y_true, y_pred_classes, y_pred_probs,
                        "INCORRECTLY PREDICTED EXAMPLES", incorrect_indices, 5)

# Show high confidence predictions
high_conf_indices = np.where(np.max(y_pred_probs, axis=1) > 0.9)[0]
show_prediction_examples(test_sentences, y_true, y_pred_classes, y_pred_probs,
                        "HIGH CONFIDENCE PREDICTIONS (>90%)", high_conf_indices, 5)

# Show low confidence predictions  
low_conf_indices = np.where(np.max(y_pred_probs, axis=1) < 0.7)[0]
show_prediction_examples(test_sentences, y_true, y_pred_classes, y_pred_probs,
                        "LOW CONFIDENCE PREDICTIONS (<70%)", low_conf_indices, 5)


EXAMPLE PREDICTIONS (Text + Predictions)

CORRECTLY PREDICTED EXAMPLES
----------------------------

Example 1: ✓ CORRECT
Text: "i've got some bad news about the bike you lent me."
True Label: 1 (Correct English)
Predicted: 1 (Correct English) [Confidence: 0.799]

Example 2: ✓ CORRECT
Text: "i've got some bad news about the bike lent me."
True Label: 0 (Incorrect English)
Predicted: 0 (Incorrect English) [Confidence: 0.947]

Example 3: ✓ CORRECT
Text: "i've some bad news about the bike you got lent me."
True Label: 0 (Incorrect English)
Predicted: 0 (Incorrect English) [Confidence: 0.999]

Example 4: ✓ CORRECT
Text: "i've got any bad news about the bike you lent me."
True Label: 0 (Incorrect English)
Predicted: 0 (Incorrect English) [Confidence: 0.524]

Example 5: ✓ CORRECT
Text: "what's that?"
True Label: 1 (Correct English)
Predicted: 1 (Correct English) [Confidence: 0.836]

INCORRECTLY PREDICTED EXAMPLES
------------------------------

Example 12: ✗ WRONG
Text: "he fell on the way t

In [19]:
# ========================================
# INTERACTIVE INFERENCE FUNCTION
# ========================================

def predict_sentence_correctness(text, model=model, tokenizer=tokenizer):
    """
    Predict whether a sentence is grammatically correct or incorrect.

    Args:
        text (str): The sentence to evaluate
        model: The trained model
        tokenizer: The tokenizer

    Returns:
        dict: Dictionary containing prediction, confidence, and probabilities
    """
    # Ensure model and tensors are on the same device
    device = next(model.parameters()).device
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    # Move input tensors to the same device as the model
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get model predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = F.softmax(logits, dim=-1)
        predicted_class = torch.argmax(logits, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()

    # Interpret results
    prediction = "Correct English" if predicted_class == 1 else "Incorrect English"
    prob_correct = probabilities[0][1].item()
    prob_incorrect = probabilities[0][0].item()

    return {
        "text": text,
        "prediction": prediction,
        "predicted_class": predicted_class,
        "confidence": confidence,
        "prob_correct": prob_correct,
        "prob_incorrect": prob_incorrect
    }

# Test the function with some example sentences
print("="*80)
print("TESTING SINGLE SENTENCE PREDICTIONS")
print("="*80)

test_examples = [
    "I am going to the store.",           # Correct
    "I is going to the store.",           # Incorrect (subject-verb disagreement)
    "She writes beautiful poems.",        # Correct  
    "She write beautiful poems.",         # Incorrect (subject-verb disagreement)
    "The cat is sleeping on the sofa.",   # Correct
    "The cat are sleeping on the sofa.",  # Incorrect (subject-verb disagreement)
    "How are you doing today?",           # Correct
    "How is you doing today?",            # Incorrect
]

for i, text in enumerate(test_examples, 1):
    result = predict_sentence_correctness(text)
    print(f"\nExample {i}:")
    print(f"Text: \"{result['text']}\"")
    print(f"Prediction: {result['prediction']} (Class: {result['predicted_class']})")
    print(f"Confidence: {result['confidence']:.3f}")
    print(f"Probability Correct: {result['prob_correct']:.3f}")
    print(f"Probability Incorrect: {result['prob_incorrect']:.3f}")

print("\n" + "="*80)
print("You can now use the predict_sentence_correctness() function")
print("to test any sentence! Example:")
print('result = predict_sentence_correctness("Your sentence here")')
print("="*80)


TESTING SINGLE SENTENCE PREDICTIONS

Example 1:
Text: "I am going to the store."
Prediction: Correct English (Class: 1)
Confidence: 0.820
Probability Correct: 0.820
Probability Incorrect: 0.180

Example 2:
Text: "I is going to the store."
Prediction: Incorrect English (Class: 0)
Confidence: 0.998
Probability Correct: 0.002
Probability Incorrect: 0.998

Example 3:
Text: "She writes beautiful poems."
Prediction: Correct English (Class: 1)
Confidence: 0.601
Probability Correct: 0.601
Probability Incorrect: 0.399

Example 4:
Text: "She write beautiful poems."
Prediction: Incorrect English (Class: 0)
Confidence: 0.999
Probability Correct: 0.001
Probability Incorrect: 0.999

Example 5:
Text: "The cat is sleeping on the sofa."
Prediction: Incorrect English (Class: 0)
Confidence: 0.975
Probability Correct: 0.025
Probability Incorrect: 0.975

Example 6:
Text: "The cat are sleeping on the sofa."
Prediction: Incorrect English (Class: 0)
Confidence: 0.998
Probability Correct: 0.002
Probability Inc