In [None]:
# %% [markdown]
# # IELTS Grammar & Accuracy Classifier Notebook

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:52:13.174346Z","iopub.execute_input":"2025-10-01T18:52:13.175013Z","iopub.status.idle":"2025-10-01T18:54:42.993796Z","shell.execute_reply.started":"2025-10-01T18:52:13.174990Z","shell.execute_reply":"2025-10-01T18:54:42.993182Z"}}
# Install the notebook's dependencies. Kaggle images usually ship most of
# these already; the install is kept so the notebook also runs on a clean kernel.
# transformers and spacy are pinned to exact versions; the other packages float.
!pip install transformers==4.35.0 accelerate sentencepiece sacremoses spacy==3.6.0 regex datasets
# Fetch the small English spaCy pipeline that is loaded later with spacy.load().
!python -m spacy download en_core_web_sm

import pandas as pd
import numpy as np
import re, math, json, time, datetime
from dataclasses import dataclass, asdict

import spacy
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from difflib import SequenceMatcher

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, mean_absolute_error
from sklearn.metrics import cohen_kappa_score

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:54:42.994847Z","iopub.execute_input":"2025-10-01T18:54:42.995321Z","iopub.status.idle":"2025-10-01T18:54:43.037408Z","shell.execute_reply.started":"2025-10-01T18:54:42.995296Z","shell.execute_reply":"2025-10-01T18:54:43.036699Z"}}
# Read the essay dataset from the Kaggle input mount and preview it.
DATASET_CSV = "/kaggle/input/ielts-writing-dataset/cook.csv"
df = pd.read_csv(DATASET_CSV)
print("Dataset shape:", df.shape)
df.head()

# %% [code] {"execution":{"iopub.status.busy":"2025-10-01T18:54:45.824677Z","iopub.execute_input":"2025-10-01T18:54:45.824966Z","iopub.status.idle":"2025-10-01T18:55:17.535299Z","shell.execute_reply.started":"2025-10-01T18:54:45.824943Z","shell.execute_reply":"2025-10-01T18:55:17.534096Z"}}
# ----------- Load NLP models -----------
# spaCy pipeline: used below for sentence splitting and token statistics.
SPACY_MODEL_NAME = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL_NAME)

# Seq2seq grammar-error-correction checkpoint, loaded by name via transformers.
GEC_MODEL_NAME = "vennify/t5-base-grammar-correction"
tokenizer = AutoTokenizer.from_pretrained(GEC_MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(GEC_MODEL_NAME)

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:56:12.861567Z","iopub.execute_input":"2025-10-01T18:56:12.862644Z","iopub.status.idle":"2025-10-01T18:56:12.873473Z","shell.execute_reply.started":"2025-10-01T18:56:12.862607Z","shell.execute_reply":"2025-10-01T18:56:12.872918Z"}}
# ----------- Utility Functions -----------

def sentence_tokenize(text: str) -> list:
    """Split *text* into sentences with the module-level spaCy pipeline.

    Every sentence span is whitespace-stripped. Spans that are empty after
    stripping (for example a run of blank lines that was segmented on its
    own) are dropped, so callers never receive ``""`` as a sentence and do
    not feed empty inputs to downstream models or count them as sentences.

    Args:
        text: Raw essay text.

    Returns:
        The non-empty sentence strings, in document order. Empty or
        whitespace-only input yields an empty list.
    """
    doc = nlp(text)
    stripped = (sent.text.strip() for sent in doc.sents)
    return [sentence for sentence in stripped if sentence]

def apply_gec_t5(texts, max_length=256, batch_size=16, prefix="grammar: "):
    """Run the grammar-correction model over *texts* and return the outputs.

    Args:
        texts: A list of sentences (a single string is treated as a
            one-element list).
        max_length: Maximum number of tokens generated per sentence.
        batch_size: Number of sentences sent through the model at once, so a
            long essay does not become one arbitrarily large padded batch.
            ``None`` or a value below 1 processes everything in one batch.
        prefix: Task prefix prepended to every input. The model card for
            ``vennify/t5-base-grammar-correction`` documents inputs of the
            form ``"grammar: <sentence>"``; pass ``""`` to send raw text.

    Returns:
        A list of corrected strings, one per input, in input order. An empty
        input returns ``[]`` without touching the tokenizer or the model.
    """
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        # An empty batch is not a meaningful input for tokenization/generation.
        return []

    if batch_size is None or batch_size < 1:
        batch_size = len(texts)

    corrected = []
    for start in range(0, len(texts), batch_size):
        chunk = [prefix + t for t in texts[start:start + batch_size]]
        # truncation=True clips inputs that exceed the tokenizer's maximum length.
        batch = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
        outs = model.generate(**batch, max_length=max_length, num_beams=4, early_stopping=True)
        corrected.extend(
            tokenizer.decode(o, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for o in outs
        )
    return corrected

def token_edit_ops(orig: str, corrected: str):
    """Describe how *orig* differs from *corrected* at whitespace-token level.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``. Each opcode is returned as
    ``(tag, orig_start, corrected_start, orig_span_len, corrected_span_len)``,
    where ``tag`` is one of the SequenceMatcher tags (``'equal'``,
    ``'replace'``, ``'insert'``, ``'delete'``).
    """
    matcher = SequenceMatcher(a=orig.split(), b=corrected.split())
    return [
        (tag, a_start, b_start, a_end - a_start, b_end - b_start)
        for tag, a_start, a_end, b_start, b_end in matcher.get_opcodes()
    ]

# Characters treated as punctuation when comparing original and corrected text.
PUNCTUATION_CHARS = {'.', ',', '?', '!', ';', ':'}

def punctuation_accuracy(orig: str, corrected: str) -> float:
    """Score how well the punctuation of *orig* matches that of *corrected*.

    The punctuation characters of each string are extracted in order and
    compared position by position. The score is the number of positions that
    agree divided by the length of the longer punctuation sequence.

    Note that the comparison is positional: one inserted or removed mark
    shifts every later mark and lowers the score accordingly.

    Returns:
        A value in ``[0.0, 1.0]``; ``1.0`` when neither string contains any
        punctuation.
    """
    orig_p = [c for c in orig if c in PUNCTUATION_CHARS]
    corr_p = [c for c in corrected if c in PUNCTUATION_CHARS]

    denom = max(len(orig_p), len(corr_p))
    if denom == 0:
        # Neither side has punctuation: nothing could have been wrong.
        return 1.0

    matches = sum(1 for a, b in zip(orig_p, corr_p) if a == b)
    return matches / denom

def complexity_metrics(text: str):
    """Compute simple sentence-complexity statistics for *text*.

    The text is parsed with the module-level spaCy pipeline. The result is a
    dict with:

    * ``num_sents`` - number of sentences, floored at 1.
    * ``total_tokens`` - number of non-whitespace tokens.
    * ``avg_sent_len`` - ``total_tokens / num_sents``.
    * ``complex_structure_ratio`` - occurrences of a fixed list of
      subordinating words divided by the actual sentence count (``0.0`` when
      there are no sentences).
    """
    subordinators = {
        "although", "because", "since", "while", "whereas", "unless",
        "where", "after", "before", "though", "if", "that",
    }

    doc = nlp(text)
    sentences = list(doc.sents)
    sentence_total = len(sentences)

    token_total = sum(1 for tok in doc if not tok.is_space)
    subordinator_hits = sum(
        1
        for sentence in sentences
        for tok in sentence
        if tok.text.lower() in subordinators
    )

    # The reported sentence count is floored at 1 so the average is defined.
    num_sents = max(1, sentence_total)
    if sentence_total > 0:
        ratio = subordinator_hits / sentence_total
    else:
        ratio = 0.0

    return {
        "num_sents": num_sents,
        "total_tokens": token_total,
        "avg_sent_len": token_total / num_sents,
        "complex_structure_ratio": ratio,
    }

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:56:12.874810Z","iopub.execute_input":"2025-10-01T18:56:12.875133Z","iopub.status.idle":"2025-10-01T18:56:12.898680Z","shell.execute_reply.started":"2025-10-01T18:56:12.875116Z","shell.execute_reply":"2025-10-01T18:56:12.898133Z"}}

@dataclass
class IELTSGrammarMetrics:
    """Per-essay grammar and accuracy features produced by ``analyze_text``.

    Instances are converted with ``dataclasses.asdict``, which preserves the
    declaration order below; that order becomes the column order of the
    feature table built from these metrics, so do not reorder the fields.
    """
    # Total non-'equal' edit operations divided by the whitespace word count.
    error_density: float
    # Share of sentences for which the corrected text produced no edit operation.
    error_free_sentence_ratio: float
    # Subordinating-word occurrences per sentence, from complexity_metrics().
    complex_structure_ratio: float
    # Mean of the per-sentence punctuation_accuracy() scores.
    punctuation_accuracy: float
    # Average number of non-whitespace tokens per sentence.
    avg_sent_len: float
    # Number of whitespace-separated words in the essay.
    total_words: int
    # Number of sentences returned by sentence_tokenize().
    total_sentences: int
    # Total number of non-'equal' edit operations over all sentences.
    edits_count: int

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:56:12.899508Z","iopub.execute_input":"2025-10-01T18:56:12.899776Z","iopub.status.idle":"2025-10-01T18:56:12.917820Z","shell.execute_reply.started":"2025-10-01T18:56:12.899757Z","shell.execute_reply":"2025-10-01T18:56:12.917209Z"}}

def analyze_text(text: str) -> "IELTSGrammarMetrics":
    """Extract grammar/accuracy features from one essay.

    The essay is split into sentences, each sentence is passed through the
    grammar-correction model, and the original/corrected pairs are compared
    to count edits and score punctuation. Sentence-complexity statistics
    come from ``complexity_metrics``.

    Args:
        text: Raw essay text. Empty or whitespace-only text is accepted and
            yields all-zero counts without invoking the correction model.

    Returns:
        An ``IELTSGrammarMetrics`` instance with ratios rounded to 4 decimal
        places and ``avg_sent_len`` rounded to 2.
    """
    sents = sentence_tokenize(text)
    # With no sentences there is nothing to correct; calling the GEC model on
    # an empty batch is not a valid input, so skip it.
    corrected_sents = apply_gec_t5(sents) if sents else []

    total_edits = 0
    error_free = 0
    punc_accs = []
    for o, c in zip(sents, corrected_sents):
        # One "edit" is one non-equal opcode (a contiguous changed span),
        # not one changed token.
        edits_here = sum(1 for op in token_edit_ops(o, c) if op[0] != 'equal')
        if edits_here == 0:
            error_free += 1
        total_edits += edits_here
        punc_accs.append(punctuation_accuracy(o, c))

    total_words = len(text.split())
    total_sentences = len(sents)
    # max(1, ...) keeps every ratio defined for empty input.
    error_density = total_edits / max(1, total_words)
    error_free_sentence_ratio = error_free / max(1, total_sentences)
    punc_accuracy = sum(punc_accs) / max(1, len(punc_accs))
    comp = complexity_metrics(text)

    return IELTSGrammarMetrics(
        error_density=round(error_density, 4),
        error_free_sentence_ratio=round(error_free_sentence_ratio, 4),
        complex_structure_ratio=round(comp['complex_structure_ratio'], 4),
        punctuation_accuracy=round(punc_accuracy, 4),
        avg_sent_len=round(comp['avg_sent_len'], 2),
        total_words=total_words,
        total_sentences=total_sentences,
        edits_count=total_edits,
    )

# %% [code] {"jupyter":{"outputs_hidden":false},"execution":{"iopub.status.busy":"2025-10-01T18:56:12.918638Z","iopub.execute_input":"2025-10-01T18:56:12.918955Z"}}

# Build one feature row per essay; each row carries the grammar metrics plus
# the human-assigned band under "true_gra".
rows = []
start_time = time.time()

for _, record in df.iterrows():
    essay_text = str(record["essay"])
    # NOTE(review): int() truncates fractional scores (e.g. 6.5 -> 6) and
    # raises on NaN - confirm gra_score only holds whole, non-missing bands.
    band_label = int(record["gra_score"])
    feature_row = asdict(analyze_text(essay_text))
    feature_row["true_gra"] = band_label
    rows.append(feature_row)

elapsed = time.time() - start_time
print(f"Feature extraction done in {elapsed:.2f} seconds")

results_df = pd.DataFrame(rows)
print("Features shape:", results_df.shape)
results_df.head()

# %% [code] {"jupyter":{"outputs_hidden":false}}

# Separate the engineered features from the band label.
TARGET_COLUMN = "true_gra"
y = results_df[TARGET_COLUMN]
X = results_df.drop(columns=[TARGET_COLUMN])

# NOTE(review): test_size=0.7 holds out 70% of the essays and trains on only
# 30% - confirm this is intended rather than a swapped 0.3.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.7,
    random_state=42,
)

# %% [code] {"jupyter":{"outputs_hidden":false}}

# Random forest over the grammar features; the seed is fixed so repeated runs
# on the same split build the same forest.
rf_params = {"n_estimators": 300, "random_state": 42}
clf = RandomForestClassifier(**rf_params)
clf.fit(X_train, y_train)

# %% [code] {"jupyter":{"outputs_hidden":false}}

# Score the held-out essays: per-class report, mean absolute band error, and
# quadratic-weighted kappa between true and predicted bands.
y_pred = clf.predict(X_test)

report = classification_report(y_test, y_pred)
print("\nClassification Report:\n", report)

mae = mean_absolute_error(y_test, y_pred)
print("MAE:", mae)

qwk = cohen_kappa_score(y_test, y_pred, weights="quadratic")
print("QWK:", qwk)

# %% [code] {"jupyter":{"outputs_hidden":false}}
def predict_band(text: str) -> int:
    """Predict the grammar band for one essay with the trained classifier.

    The essay is converted to features with ``analyze_text`` and passed to
    the module-level ``clf``. The features are supplied as a one-row
    DataFrame with named columns rather than a bare array: the classifier is
    fitted on a DataFrame, so a nameless array makes scikit-learn warn about
    missing feature names and silently relies on positional order. When the
    fitted model exposes ``feature_names_in_``, the columns are put in
    exactly that order.

    Args:
        text: Raw essay text.

    Returns:
        The predicted band as an ``int``.
    """
    metrics = analyze_text(text)
    features = pd.DataFrame([asdict(metrics)])

    trained_columns = getattr(clf, "feature_names_in_", None)
    if trained_columns is not None:
        # Align to the training column order instead of trusting field order.
        features = features[list(trained_columns)]

    return int(clf.predict(features)[0])

# Smoke-test the full pipeline on a short hand-written essay.
sample = """Some people think that children should begin their formal education at a very early age.
However, others believe that preschooling is unnecessary and that young children learn best through play.
In my opinion, both points have merits but formal education should not be rushed."""
sample_band = predict_band(sample)
print("\nPredicted GRA Band for sample essay:", sample_band)