In [4]:
!pip install -q transformers accelerate sentencepiece bitsandbytes
!pip install -q pillow

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
model_id = "google/gemma-3-1b-it"  # Use "google/gemma-3b-it" if you have Colab Pro/A100
tokenizer = AutoTokenizer.from_pretrained(model_id, token="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")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_4bit=True,  # 4-bit quantization to fit in T4/A100
    token="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"
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


In [19]:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.eval()

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear4bit(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear4bit(in_features=6912, out_features=1152, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_l

In [20]:
# ===================================================================
# 2. BATCHED Helper Functions (Modified from your code)
# ===================================================================
def make_prompts(sentences: list[str]) -> list[str]:
    """Creates a list of prompts for a batch of sentences."""
    template = """You are a grammar checker.
Decide if the sentence below is grammatically correct.
Answer with only one word: "Correct" or "Incorrect".

Sentence: "{sentence}"
Answer:"""
    return [template.format(sentence=s) for s in sentences]

def check_sentences_batched(sentences: list[str]) -> list[str]:
    """
    Checks a BATCH of sentences for grammatical correctness.
    """
    # 1. Create prompts for the whole batch
    prompts = make_prompts(sentences)

    # 2. Tokenize the entire batch at once
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # 3. Generate outputs for the entire batch
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=5,
            temperature=0.0,
            do_sample=False
        )

    # 4. Decode all results at once
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # 5. Parse the answer for each result in the batch
    batch_results = []
    for text in generated_texts:
        try:
            answer = text.split("Answer:")[-1].strip().split()[0]
            if answer.lower().startswith("c"):
                batch_results.append("Correct")
            else:
                batch_results.append("Incorrect")
        except IndexError:
            # Handle cases where the model might fail to generate a proper answer
            batch_results.append("Error")

    return batch_results

In [21]:
# ===================================================================
# 3. Load dataset and run the batched check
# ===================================================================
try:
    # Use your specified column name 'sentences'
    df = pd.read_csv("sentences.csv")
    df.dropna(subset=['sentences'], inplace=True)
    sentences_to_check = df['sentences'].tolist()
    print(f"Loaded {len(sentences_to_check)} sentences.")

    BATCH_SIZE = 32  # You can adjust this based on your GPU memory
    results = []

    # Use tqdm to create a progress bar for the BATCHES
    for i in tqdm(range(0, len(sentences_to_check), BATCH_SIZE), desc="Checking grammar in batches"):
        batch = sentences_to_check[i : i + BATCH_SIZE]
        results.extend(check_sentences_batched(batch))

    df["grammar_check"] = results

    # Save the final results
    df.to_csv("sentences_with_grammar_batched.csv", index=False)
    print("\n✅ Done! Saved to sentences_with_grammar_batched.csv")

except FileNotFoundError:
    print("❌ Error: 'sentences.csv' not found. Please make sure the file is uploaded.")
except KeyError:
    print("❌ Error: The CSV file must have a column named 'sentences'.")

Loaded 89134 sentences.


Checking grammar in batches:   0%|          | 0/2786 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Checking grammar in batches: 100%|██████████| 2786/2786 [1:03:03<00:00,  1.36s/it]



✅ Done! Saved to sentences_with_grammar_batched.csv


In [4]:
import pandas as pd

# Define the names of your input and output files
input_filename = 'sentences_with_grammar_batched.csv'
output_filename = 'Gramitically_correct_dataset.csv'

try:
    df = pd.read_csv(input_filename)
    filtered_df = df[df['grammar_check'] == 'Correct']
    final_df = filtered_df[['uid', 'sentences']].copy()
    final_df.to_csv(output_filename, index=False)

    print(f"✅ Success! Filtered data without the 'grammar_check' column has been saved to '{output_filename}'")

except FileNotFoundError:
    print(f"❌ Error: The file '{input_filename}' was not found. Please make sure it's in the same directory as the script.")
except KeyError as e:
    print(f"❌ Error: A required column was not found. Please check your column headers. Details: {e}")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

✅ Success! Filtered data without the 'grammar_check' column has been saved to 'Gramitically_correct_dataset.csv'
