# Writing Assistance Tools: Text Correction using SLM and LLM

This notebook implements a **writing assistance pipeline** focused on grammar and spelling correction, combining:
- **Small Language Model (SLM)** for grammatical acceptability detection (using a DistilBERT variant fine-tuned on CoLA),
- **Large Language Model / Sequence-to-Sequence model (LLM)** for generative correction (using the base T5-small checkpoint, fine-tuned in this notebook on a small synthetic set of erroneous → corrected sentence pairs).

The pipeline demonstrates how to detect problematic sentences and propose corrected rewrites, and evaluates corrections using standard metrics such as BLEU, Levenshtein (edit) distance, and simple overlap-based scores.

The design is modular so it can be extended to real datasets (e.g., JFLEG, BEA, CoNLL), integrated with user-facing interfaces, or deployed in production.

In [None]:
# Install required packages (run this cell once)
import sys
import subprocess

def install(package):
    """Install ``package`` with pip, using the interpreter that runs this notebook.

    ``subprocess.check_call`` raises ``subprocess.CalledProcessError`` when pip
    exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Core NLP / evaluation libraries.
# torch, sentencepiece and accelerate are listed explicitly: later cells import
# torch directly, T5Tokenizer is the SentencePiece-based tokenizer, and Trainer
# depends on accelerate in the transformers versions requested here.
required = [
    "torch",
    "transformers>=4.30.0",
    "sentencepiece",
    "accelerate",
    "datasets",
    "evaluate",
    "rouge_score",
    "nltk",
]
for pkg in required:
    # Drop the version specifier to obtain the importable module name.
    module_name = pkg.split('>=')[0]
    try:
        __import__(module_name)
    except ImportError:
        # Only install what is actually missing from the environment.
        install(pkg)

# Download NLTK data needed
import nltk
# word_tokenize reads 'punkt' on older NLTK releases and 'punkt_tab' on newer
# ones, so fetch both to work with either.
for resource in ('punkt', 'punkt_tab'):
    nltk.download(resource)

In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, DistilBertTokenizer, AutoModelForSequenceClassification, pipeline
from transformers import Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset
import nltk
from nltk.tokenize import word_tokenize
from evaluate import load as load_metric
import math

# Helper: simple Levenshtein distance implementation
def levenshtein_distance(a: str, b: str) -> int:
    """Return the character-level edit distance between ``a`` and ``b``.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn ``a`` into ``b``.
    """
    if not a:
        return len(b)
    if not b:
        return len(a)

    # Only the previous DP row is needed to compute the current one.
    previous = list(range(len(b) + 1))
    for row, ch_a in enumerate(a, start=1):
        current = [row]
        for col, ch_b in enumerate(b, start=1):
            deletion = previous[col] + 1
            insertion = current[col - 1] + 1
            substitution = previous[col - 1] + (0 if ch_a == ch_b else 1)
            current.append(min(deletion, insertion, substitution))
        previous = current
    return previous[-1]

# Token-based similarity (optional)
def word_overlap_score(pred: str, reference: str) -> float:
    """Return the fraction of distinct reference tokens that also occur in ``pred``.

    Both strings are lower-cased and tokenized with ``word_tokenize``; duplicate
    tokens are collapsed. An empty reference yields ``0.0``.
    """
    predicted_vocab = set(word_tokenize(pred.lower()))
    reference_vocab = set(word_tokenize(reference.lower()))
    if len(reference_vocab) == 0:
        return 0.0
    shared = predicted_vocab & reference_vocab
    return len(shared) / len(reference_vocab)

# Load evaluation metrics (BLEU first, then ROUGE)
bleu, rouge = (load_metric(metric_name) for metric_name in ('bleu', 'rouge'))

In [None]:
# Create a small synthetic dataset of erroneous -> corrected pairs.
# Each tuple is (sentence with an error, its corrected form).
examples = [
    {"source": erroneous, "target": corrected}
    for erroneous, corrected in (
        ("She go to school every day.", "She goes to school every day."),
        ("I has a meeting tomorrow.", "I have a meeting tomorrow."),
        ("They is playing football.", "They are playing football."),
        ("He don't like apples.", "He doesn't like apples."),
        ("The cat chase the mouse.", "The cat chases the mouse."),
        ("This sentences are wrong.", "These sentences are wrong."),
        ("We was late to the party.", "We were late to the party."),
        ("Your going to love it.", "You're going to love it."),
        ("I will finished it soon.", "I will finish it soon."),
        ("Its a beautiful day.", "It's a beautiful day."),
    )
]

# Seq2seq view of the pairs: the model input carries the "correct: " task
# prefix, the training target is the corrected sentence.
dataset = Dataset.from_list(
    [
        {"input_text": f"correct: {ex['source']}", "target_text": ex["target"]}
        for ex in examples
    ]
)

# Quick peek
print('Sample example:', dataset[0])

In [None]:
# Load generative correction model: T5-small
t5_model_name = "t5-small"
t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_model = T5ForConditionalGeneration.from_pretrained(t5_model_name)

# Load SLM for grammatical acceptability detection (CoLA)
# We'll use a model fine-tuned on CoLA for acceptability judgments
cola_model_name = "textattack/distilbert-base-uncased-CoLA"
cola_tokenizer = DistilBertTokenizer.from_pretrained(cola_model_name)
cola_model = AutoModelForSequenceClassification.from_pretrained(cola_model_name)
# The deprecated `return_all_scores=False` flag is omitted: the pipeline's
# default already returns only the top-scoring label for each input.
cola_pipe = pipeline("text-classification", model=cola_model, tokenizer=cola_tokenizer)

In [None]:
# Tokenize dataset for T5 training
max_input_length = 64
max_target_length = 64

def preprocess_function(example):
    """Tokenize one dataset row into encoder inputs and decoder labels.

    Args:
        example: a row with 'input_text' (prefixed source sentence) and
            'target_text' (corrected sentence).

    Returns:
        The tokenizer output for the input text, with the target token ids
        stored under 'labels'.
    """
    # No padding here: the seq2seq data collator pads each batch dynamically
    # and fills label padding with -100 so padded positions are excluded from
    # the loss. Pre-padding labels with the pad token id would make the loss
    # count those positions.
    model_inputs = t5_tokenizer(example['input_text'], truncation=True,
                                max_length=max_input_length)
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context.
    labels = t5_tokenizer(text_target=example['target_text'], truncation=True,
                          max_length=max_target_length)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=False)  # small, so batched=False is okay
# Data collator handles padding properly (inputs with the pad token, labels with -100)
data_collator = DataCollatorForSeq2Seq(tokenizer=t5_tokenizer, model=t5_model)

In [None]:
# Set up training arguments (small scale for demonstration)
training_args = TrainingArguments(
    output_dir="./t5-writing-assist",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    save_strategy="no",  # nothing is checkpointed in this demo
    logging_strategy="steps",
    logging_steps=5,
    report_to=[]  # disable external logging (e.g., wandb) for portability
)

# The `tokenizer=` argument is intentionally not passed: newer transformers
# releases deprecate it in favour of `processing_class`, and it is not needed
# here because the data collator is supplied explicitly and no checkpoints
# are saved.
trainer = Trainer(
    model=t5_model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator
)

# Train (this is small, should finish quickly)
trainer.train()

In [None]:
# Function to generate correction with T5
def correct_with_t5(sentence: str, num_beams: int = 3, max_length: int = 64) -> str:
    """Generate a corrected rewrite of ``sentence`` with the T5 model.

    Args:
        sentence: the raw sentence to correct (the "correct: " task prefix is
            added here).
        num_beams: beam width used for generation.
        max_length: token limit applied both to the truncated input and to the
            generated output.

    Returns:
        The decoded generation with special tokens stripped.
    """
    input_text = "correct: " + sentence
    inputs = t5_tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_length)
    # After training, the model may live on an accelerator; the inputs must be
    # on the same device or generate() fails with a device mismatch.
    inputs = inputs.to(t5_model.device)
    # Training leaves the model in train mode; switch to eval so dropout does
    # not perturb generation.
    t5_model.eval()
    outputs = t5_model.generate(**inputs, num_beams=num_beams, max_length=max_length, early_stopping=True)
    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Function to check grammatical acceptability using CoLA-model (SLM)
def check_acceptability(sentence: str):
    """Run the CoLA classification pipeline on ``sentence`` and return its raw output.

    The label names depend on the loaded model, so the result is passed through
    unchanged for display rather than being interpreted here.
    """
    return cola_pipe(sentence)

# Evaluate on synthetic dataset
def evaluate_corrections(examples_list):
    """Correct every source sentence and score the result against its target.

    For each example the T5 prediction is compared with the target using
    sentence-level BLEU, character-level Levenshtein distance and the word
    overlap score; per-example results and the averages are printed.

    Args:
        examples_list: iterable of dicts with 'source' and 'target' strings.

    Returns:
        A dict with the average 'bleu', 'edit_distance' and 'overlap' values
        and the list of 'predictions' (in input order). All averages are 0.0
        when ``examples_list`` is empty.
    """
    if not examples_list:
        # Avoid dividing by zero when there is nothing to average.
        print("No examples to evaluate.")
        return {"bleu": 0.0, "edit_distance": 0.0, "overlap": 0.0, "predictions": []}

    bleu_scores = []
    edit_distances = []
    overlap_scores = []
    predictions = []
    for ex in examples_list:
        src = ex['source']
        tgt = ex['target']
        pred = correct_with_t5(src)
        if pred.split() and tgt.split():
            # The `evaluate` BLEU metric takes raw strings (one prediction, a
            # list of references per prediction) and tokenizes internally;
            # str.split keeps the whitespace tokenization used for scoring.
            bleu_score = bleu.compute(
                predictions=[pred], references=[[tgt]], tokenizer=str.split
            )['bleu']
        else:
            # BLEU is undefined for an empty prediction or reference.
            bleu_score = 0.0
        # Levenshtein on raw strings
        ed = levenshtein_distance(pred, tgt)
        overlap = word_overlap_score(pred, tgt)
        bleu_scores.append(bleu_score)
        edit_distances.append(ed)
        overlap_scores.append(overlap)
        predictions.append(pred)
        print(f"Source: {src}\nTarget: {tgt}\nPredicted: {pred}\nBLEU: {bleu_score:.3f}, Edit Distance: {ed}, Overlap: {overlap:.3f}\n---")

    count = len(predictions)
    avg_bleu = sum(bleu_scores) / count
    avg_edit = sum(edit_distances) / count
    avg_overlap = sum(overlap_scores) / count
    print(f"\nAverage BLEU: {avg_bleu:.3f}, Average Edit Distance: {avg_edit:.2f}, Average Overlap: {avg_overlap:.3f}")
    return {
        "bleu": avg_bleu,
        "edit_distance": avg_edit,
        "overlap": avg_overlap,
        "predictions": predictions,
    }

In [None]:
# Combined writing assistance function
def writing_assistant(sentence: str):
    """Print the acceptability verdict and the T5 suggestion for ``sentence``.

    The sentence is first passed to the acceptability check, then to the T5
    corrector; finally the edit distance between the suggestion and the input
    is reported (or a note that nothing changed).
    """
    print(f"Input Sentence: {sentence}\n")
    slm_verdict = check_acceptability(sentence)
    print(f"Grammatical Acceptability (SLM): {slm_verdict}\n")
    suggestion = correct_with_t5(sentence)
    print(f"T5 Suggestion: {suggestion}\n")
    # Report how far the suggestion is from the input, in character edits.
    distance = levenshtein_distance(suggestion, sentence)
    message = (
        "No changes suggested by the generator."
        if distance == 0
        else f"Edit distance between input and suggestion: {distance}"
    )
    print(message)

# Example usage: score the synthetic pairs, then run the assistant on a few
# unseen sentences.
print('--- Evaluation on synthetic dataset ---')
evaluate_corrections(examples)

print('\n--- Writing assistant demos ---')
for demo_sentence in (
    "She dont like the movie.",
    "I will go to the market yesterday.",
    "He has a dogs.",
):
    writing_assistant(demo_sentence)