In [1]:
import torch
# Report whether CUDA is usable. The device name is only queried when a GPU is
# present, because torch.cuda.get_device_name raises on CPU-only machines and
# would otherwise abort a fresh "Run All" on such hardware.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device detected; running on CPU")


True
NVIDIA GeForce RTX 4070 Laptop GPU


In [13]:
import polars as pl
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import re
import torch

# --- Choose the word-level spell-checking engine ---
# Option 1: Hunspell (used when both the bindings and the dictionaries load)
try:
    import hunspell
    h_en = hunspell.HunSpell('/usr/share/hunspell/en_US.dic', '/usr/share/hunspell/en_US.aff')
    h_fr = hunspell.HunSpell('/usr/share/hunspell/fr_FR.dic', '/usr/share/hunspell/fr_FR.aff')
    # Only flag Hunspell as usable once both dictionaries are actually loaded.
    HUNSPELL_AVAILABLE = True
except Exception as hunspell_error:
    # Catch Exception (not a bare except) so KeyboardInterrupt / SystemExit
    # still propagate, and report why the fallback engine is being used.
    HUNSPELL_AVAILABLE = False
    print(f"Hunspell unavailable ({hunspell_error!r}); falling back to PyEnchant.")

# Option 2: PyEnchant (fallback)
if not HUNSPELL_AVAILABLE:
    import enchant
    d_en = enchant.Dict("en_US")
    # correct_word reads d_fr for every non-English language, so it must exist.
    # It is None when no French dictionary is installed, so English-only
    # setups still load this cell without error.
    d_fr = enchant.Dict("fr_FR") if enchant.dict_exists("fr_FR") else None


# --- Word-level correction ---
def correct_word(word, lang='en'):
    """Spell-correct a single whitespace-delimited token, keeping its punctuation.

    Leading and trailing non-alphanumeric characters (quotes, commas,
    brackets, ...) are set aside, the remaining core is checked with the
    active engine (Hunspell when HUNSPELL_AVAILABLE, otherwise PyEnchant), and
    the original prefix/suffix are re-attached around the result.

    Args:
        word: The token to correct.
        lang: 'en' selects the English dictionary; any other value selects
            the French one.

    Returns:
        The token unchanged when its core contains no letter, is spelled
        correctly, or has no suggestion; otherwise the token with its core
        replaced by the engine's first suggestion.

    Raises:
        RuntimeError: If the dictionary selected by `lang` is None.
    """
    # Locate the alphanumeric core so that "helo," is checked as "helo" and
    # the comma survives the correction.
    start, end = 0, len(word)
    while start < end and not word[start].isalnum():
        start += 1
    while end > start and not word[end - 1].isalnum():
        end -= 1
    core = word[start:end]

    # Pure punctuation or numbers: nothing a dictionary can meaningfully fix.
    if not any(ch.isalpha() for ch in core):
        return word

    if HUNSPELL_AVAILABLE:
        checker = h_en if lang == 'en' else h_fr
    else:
        checker = d_en if lang == 'en' else d_fr
    if checker is None:
        raise RuntimeError(f"No spell-check dictionary available for lang={lang!r}")

    # The two engines expose the same operations under different method names.
    is_known = checker.spell if HUNSPELL_AVAILABLE else checker.check
    if is_known(core):
        return word

    suggestions = checker.suggest(core)
    if not suggestions:
        return word
    return word[:start] + suggestions[0] + word[end:]


def correct_review_fast(review, lang='en'):
    """Spell-correct a review token by token.

    The text is split on whitespace, each token is passed to correct_word,
    and the results are re-joined with single spaces (so runs of whitespace,
    including newlines, collapse to one space).

    Args:
        review: The review text. Non-string values (e.g. None for a missing
            cell) are returned unchanged instead of raising.
        lang: Language code forwarded to correct_word.

    Returns:
        The corrected text, or `review` itself when it is not a string.
    """
    # Null / non-text cells have nothing to correct and no .split() method.
    if not isinstance(review, str):
        return review
    return ' '.join(correct_word(token, lang) for token in review.split())


# --- Split long texts ---
def _split_long_sentence(sentence, max_chars):
    """Break one over-long sentence into pieces of at most max_chars characters.

    Words are packed greedily; a single word longer than max_chars is sliced.
    Whitespace inside the sentence is normalised to single spaces.
    Requires max_chars >= 1.
    """
    pieces = []
    current = ""
    for word in sentence.split():
        # A word that cannot fit in any chunk is hard-sliced.
        while len(word) > max_chars:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:max_chars])
            word = word[max_chars:]
        if not word:
            continue
        if not current:
            current = word
        elif len(current) + 1 + len(word) <= max_chars:
            current = f"{current} {word}"
        else:
            pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text, max_chars=500):
    """Split text into sentence-aligned chunks of at most max_chars characters.

    Sentences are detected as spans ending in '.', '!' or '?' followed by one
    or more spaces, and are packed greedily into chunks joined by a single
    space. The joining space counts toward the limit. A sentence longer than
    max_chars is itself broken on word boundaries (or sliced, for a single
    over-long word) so that no chunk exceeds the limit.

    Args:
        text: The text to split.
        max_chars: Maximum chunk length in characters. Values below 1 disable
            the limit on individual sentences and yield one sentence per chunk.

    Returns:
        A list of chunks. If text is not a string, or is empty or whitespace
        only, the result is `[text]`.
    """
    if not isinstance(text, str) or not text.strip():
        return [text]

    # First make sure every piece individually fits within the limit.
    pieces = []
    for sentence in re.split(r'(?<=[.!?]) +', text):
        if max_chars < 1 or len(sentence) <= max_chars:
            pieces.append(sentence)
        else:
            pieces.extend(_split_long_sentence(sentence, max_chars))

    # Then pack consecutive pieces into chunks.
    chunks = []
    current = ""
    for piece in pieces:
        if not piece:
            # re.split yields a trailing '' when the text ends in ". ".
            continue
        if not current:
            current = piece
        elif len(current) + 1 + len(piece) <= max_chars:  # +1 for the joining space
            current = f"{current} {piece}"
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks or [text]


# --- Context-aware correction with a lightweight Transformer ---
def init_transformer(model_path="../../models/grammar_correcter"):
    """Build the text2text-generation pipeline used for grammar correction.

    Args:
        model_path: Model location handed to `transformers.pipeline`.

    Returns:
        The pipeline, placed on device 0 when CUDA is available and on the
        CPU (device -1) otherwise.
    """
    # Imported here rather than at module level, as in the original cell.
    from transformers import pipeline

    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1
    return pipeline("text2text-generation", model=model_path, device=target_device)


def correct_text_transformer(text, corrector, max_length=1024, chunk_size=500):
    """Run the grammar-correction pipeline over a text, chunk by chunk.

    The text is cut with split_text_into_chunks, all chunks are passed to
    `corrector` in one call, and the "generated_text" of each result is
    joined back with single spaces.

    Args:
        text: The text to correct. Non-string, empty or whitespace-only
            values are returned unchanged without calling the model.
        corrector: Callable taking a list of strings plus a `max_length`
            keyword and returning one {"generated_text": ...} dict per input.
        max_length: Forwarded to `corrector` as `max_length`.
        chunk_size: Maximum chunk length passed to split_text_into_chunks.

    Returns:
        The corrected text, or `text` itself when there is nothing to correct.
    """
    # Never send None / blank input to the model: there is nothing to correct,
    # and a generative model may invent text for an empty prompt.
    if not isinstance(text, str) or not text.strip():
        return text

    chunks = split_text_into_chunks(text, chunk_size)
    results = corrector(chunks, max_length=max_length)
    return " ".join(result["generated_text"] for result in results)


# --- Hybrid pipeline ---
def correct_reviews_hybrid(
    df: pl.DataFrame,
    column_name: str,
    lang='en',
    fast_threshold=500,  # reviews longer than this many chars also get the transformer pass
    batch_size_fast=5000,
    batch_size_transformer=8,
    # NOTE(review): init_transformer's own default is "../../models/grammar_correcter";
    # confirm which relative path is the intended one.
    transformer_model_path="./models/grammar_correcter",
    max_length=1024,
    chunk_size=500,
    n_threads=8
):
    """Correct a text column with a word-level pass, then a transformer pass.

    Step 1 runs correct_review_fast on every string value of the column.
    Step 2 runs correct_text_transformer on the step-1 results that are
    longer than `fast_threshold` characters. Non-string values (e.g. None)
    are carried through both steps unchanged.

    Args:
        df: Input frame.
        column_name: Name of the text column to correct.
        lang: Language code forwarded to correct_review_fast.
        fast_threshold: Length (in characters) above which a review is also
            sent to the transformer.
        batch_size_fast: Number of reviews per step-1 batch.
        batch_size_transformer: Number of long reviews per progress-bar step
            in step 2.
        transformer_model_path: Model location forwarded to init_transformer.
        max_length: Forwarded to correct_text_transformer.
        chunk_size: Forwarded to correct_text_transformer.
        n_threads: `max_workers` of the thread pool used in step 1.

    Returns:
        A new frame in which `column_name` holds the corrected texts.
    """
    # Step 1: fast word-level pre-correction (Hunspell / PyEnchant) on every review
    print("Step 1: Fast word-level correction (CPU, multithreaded)...")
    reviews = df[column_name].to_list()

    def fast_worker(batch):
        # Null cells have no text to correct; pass them through untouched.
        return [correct_review_fast(r, lang) if isinstance(r, str) else r for r in batch]

    corrected_fast = []
    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        for batch_start in tqdm(range(0, len(reviews), batch_size_fast), desc="Fast CPU correction"):
            batch = reviews[batch_start:batch_start + batch_size_fast]
            # NOTE(review): .result() blocks until this batch is done, so batches
            # run one at a time and the pool gives no real parallelism. Kept as
            # is because the shared spell-checker handles may not be safe to
            # call concurrently; confirm before submitting batches in parallel.
            corrected_fast.extend(executor.submit(fast_worker, batch).result())

    # Step 2: context-aware transformer correction on long texts
    print("Step 2: Transformer correction for long reviews (GPU)...")
    final_corrected = corrected_fast.copy()

    # Keep only the long reviews (null cells can never qualify)
    long_indices = [
        i for i, r in enumerate(corrected_fast)
        if isinstance(r, str) and len(r) > fast_threshold
    ]

    # Loading the model is expensive; skip it when no review needs it.
    corrector = init_transformer(transformer_model_path) if long_indices else None

    for i in tqdm(range(0, len(long_indices), batch_size_transformer), desc="Transformer GPU correction"):
        for idx in long_indices[i:i + batch_size_transformer]:
            final_corrected[idx] = correct_text_transformer(
                corrected_fast[idx],
                corrector,
                max_length=max_length,
                chunk_size=chunk_size
            )

    return df.with_columns(pl.Series(name=column_name, values=final_corrected))


In [None]:
# Load the raw reviews, run the hybrid correction on the "review" column and
# write the corrected frame next to the input file.
INPUT_CSV = "../../data/original/dataset/data_accessiblego.csv"
OUTPUT_CSV = "../../data/original/dataset/test.csv"

df = pl.read_csv(INPUT_CSV)
df_corrected = correct_reviews_hybrid(df, "review")
df_corrected.write_csv(OUTPUT_CSV)


Step 1: Fast word-level correction (CPU, multithreaded)...


Fast CPU correction:   0%|          | 0/1 [00:00<?, ?it/s]