# Imports and Installs

In [None]:
import pandas as pd
import numpy as np
import re
import os
import difflib
import nltk
nltk.download('punkt_tab')
from collections import Counter, defaultdict
from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

# Download only the NLTK resources that are not already installed.
# Each entry is (resource path looked up by nltk.data.find, package name for
# nltk.download). The WordNet path was previously misspelled 'corpra/wordnet',
# so the lookup never succeeded and WordNet was re-downloaded on every run.
NLTK_RESOURCES = [
    ('tokenizers/punkt', 'punkt'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
    ('corpora/wordnet', 'wordnet'),
]
for _resource_path, _package in NLTK_RESOURCES:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)


!pip install pandas transformers datasets accelerate jiwer scikit-learn sentencepiece
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

import torch
from sklearn.model_selection import train_test_split
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from datasets import Dataset
from jiwer import wer

from google.colab import drive

from nltk.translate.bleu_score import sentence_bleu

# Model Config

In [None]:
# Pretrained checkpoint to fine-tune. 't5-base' gives better results but is
# slower; switch to 't5-small' for speed / lower GPU memory use.
MODEL_CHECKPOINT = "t5-base"
BATCH_SIZE = 32          # per-device batch size, used for both training and evaluation
EPOCHS = 20              # number of training epochs
MAX_INPUT_LENGTH = 128   # inputs ("fix spelling: ...") are truncated to this many tokens
MAX_TARGET_LENGTH = 128  # target sentences are truncated to this many tokens

# 1. Prepare Dataset

## 1.1 Load and Clean Dataset

In [None]:
# Mount Google Drive so the dataset spreadsheet below can be read.
drive.mount('/content/drive')

xlsx_path = "/content/drive/MyDrive/NLP Assignment Submission/Spell_Correction_for_ASR_Noun_Enhancement_assignment_dataset.xlsx"


# Report GPU availability. torch.cuda.get_device_name() raises when no CUDA
# device is present, so it is only queried when a GPU was actually found.
gpu_available = torch.cuda.is_available()
if not gpu_available:
    print("WARNING: GPU not found. Training will be slow. Go to Runtime > Change runtime type > T4 GPU.")
print(gpu_available)
if gpu_available:
    print(torch.cuda.get_device_name(0))

In [None]:

# Data Cleaning function

# Noise removed only at the two ends of a sentence: whitespace, double quotes
# (often CSV artifacts), stop punctuation (, ? .) and bullet points. Inside a
# sentence these characters are kept, since they can be part of a misspelling.
# All of them are stripped together, so their order and any whitespace mixed
# in between do not matter.
_EDGE_NOISE_RE = re.compile(r'^[\s",?.•]+|[\s",?.•]+$')


def clean_text(text):
    """Normalize one sentence from the dataset.

    Steps:
      1. Smart apostrophes (’) become plain apostrophes (').
      2. Runs of whitespace, double quotes, commas, question marks, periods
         and bullet points (•) are removed from both ends of the string,
         e.g. ``'"Hello, world." '`` -> ``'Hello, world'``.
      3. Internal runs of whitespace collapse to a single space.

    Args:
        text: Raw cell value. Non-string values are returned as ``str(text)``
            without further cleaning.

    Returns:
        The cleaned string (possibly empty).
    """
    if not isinstance(text, str):
        return str(text)

    text = text.replace("’", "'")  # smart apostrophe -> standard apostrophe
    text = _EDGE_NOISE_RE.sub('', text)
    # Edges are already free of whitespace, so collapsing cannot add new ones.
    return re.sub(r'\s+', ' ', text)

def load_and_clean_data(path):
    """Load the spelling-correction spreadsheet and build T5 training pairs.

    The sheet must have exactly two columns, in order: the correct sentence
    and its misspelled version. Rows with a missing value in either column,
    or that are empty after cleaning, are dropped.

    Args:
        path: Path to the ``.xlsx`` file.

    Returns:
        DataFrame with columns ``correct``, ``incorrect``, ``input_text``
        (``"fix spelling: " + incorrect``) and ``target_text`` (= ``correct``).

    Raises:
        Any error from reading the file or renaming its columns (e.g. a
        ValueError when the sheet does not have two columns) is printed and
        re-raised.
    """
    print(f"Loading data from {path}...")
    try:
        df = pd.read_excel(path)
        df.columns = ['correct', 'incorrect']
    except Exception as e:
        # Fail loudly: a placeholder dataset cannot be split 70/15/15 and
        # would only move the failure somewhere more confusing.
        print(f"Error loading file: {e}")
        raise

    # astype(str) would turn empty cells (NaN) into the literal string "nan",
    # creating bogus "nan" -> "nan" training pairs, so drop them first.
    df = df.dropna(subset=['correct', 'incorrect'])

    # Cleaning inputs
    df['correct'] = df['correct'].astype(str).apply(clean_text)
    df['incorrect'] = df['incorrect'].astype(str).apply(clean_text)

    # Pairs that became empty after cleaning carry no training signal.
    df = df[(df['correct'] != '') & (df['incorrect'] != '')].reset_index(drop=True)

    # Preparing for T5 training: adding the task prefix
    df['input_text'] = "fix spelling: " + df['incorrect']
    df['target_text'] = df['correct']

    return df

# Read & clean Data
df = load_and_clean_data(xlsx_path)

## 1.2 Prepare Train, Validation and Test Datasets

In [None]:
# Splitting Data (70% Train, 15% Val, 15% Test)
train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Converting to HF Datasets. The shuffled pandas index is not a feature, so it
# is not carried over as an extra "__index_level_0__" column.
train_dataset = Dataset.from_pandas(train_df, preserve_index=False)
val_dataset = Dataset.from_pandas(val_df, preserve_index=False)
test_dataset = Dataset.from_pandas(test_df, preserve_index=False)

# Tokenization
tokenizer = T5Tokenizer.from_pretrained(MODEL_CHECKPOINT)


def preprocess_function(examples):
    """Tokenize a batch of (input, target) pairs for seq2seq training.

    Args:
        examples: Batch dict from ``Dataset.map(batched=True)`` holding the
            ``input_text`` and ``target_text`` lists.

    Returns:
        The tokenizer output for the inputs, truncated to MAX_INPUT_LENGTH,
        with an added ``labels`` entry: the target token ids, truncated to
        MAX_TARGET_LENGTH.
    """
    model_inputs = tokenizer(
        examples["input_text"], max_length=MAX_INPUT_LENGTH, truncation=True
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples["target_text"], max_length=MAX_TARGET_LENGTH, truncation=True
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


print("Tokenizing datasets...")
tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)
tokenized_test = test_dataset.map(preprocess_function, batched=True)

# 2. Model Setup and Training

In [None]:
# Model Setup
model = T5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
# Pads each batch of tokenized inputs and labels to a common length.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Training Parameters
args = Seq2SeqTrainingArguments(
    output_dir="./t5-medical-correction",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=EPOCHS,
    predict_with_generate=True,
    # Without this, generation falls back to the model's generation-config
    # max_length (the transformers default is 20 tokens), which truncates
    # longer corrections far below MAX_TARGET_LENGTH.
    generation_max_length=MAX_TARGET_LENGTH,
    # After many epochs the last checkpoint is not necessarily the best one:
    # restore the checkpoint with the lowest validation loss when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Mixed precision if on GPU.
    # NOTE(review): T5 is reported to produce NaN losses under fp16 on some
    # setups; if the training loss becomes NaN, set fp16=False.
    fp16=torch.cuda.is_available(),
    logging_steps=100,
    report_to="none",
)

# Initializing Trainer
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

# Training
print("\nStarting Training...")
trainer.train()

# 3. Inference

## 3.1 Inference on Test Dataset

In [None]:
predict_results = trainer.predict(tokenized_test)
predictions = predict_results.predictions
# The predictions may come back as a tuple; the generated token ids are the
# first element.
if isinstance(predictions, tuple):
    predictions = predictions[0]

# Replace -100 (the value the Trainer uses to pad predictions of different
# lengths when concatenating/gathering them) with the pad token id. -100 is
# not a valid token id and makes decoding fail, whereas the pad token is
# dropped by skip_special_tokens=True.
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)

# Decoding predictions
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(tokenized_test['labels'], skip_special_tokens=True)

## 3.2 Inference on Sample Test Examples

In [None]:
print("\n--- Example Corrections ---")
# Show up to five test examples; a small test split may have fewer rows.
num_examples = min(5, len(test_df))
for i in range(num_examples):
    print(f"Input:    {test_df.iloc[i]['incorrect']}")
    print(f"Pred:     {decoded_preds[i]}")
    print(f"Actual:   {decoded_labels[i]}")
    print("-" * 40)

# 4. Evaluation

## 4.1 Advanced Model

#### WER and accuracy

In [None]:
# Single WER over the whole test set (jiwer receives the full lists of
# references and hypotheses). "Accuracy" is just 1 - WER, so it becomes
# negative whenever WER exceeds 1.
word_error_rate = wer(decoded_labels, decoded_preds)
accuracy = 1 - word_error_rate

report = (
    "\nResults:",
    f"Word Error Rate (WER): {word_error_rate:.4f}",
    f"Accuracy (Approx): {accuracy:.4f}",
)
for report_line in report:
    print(report_line)

In [None]:
def calculate_wer(reference, hypothesis):
    """
    Word Error Rate of one hypothesis against its reference, computed by jiwer.

    A blank (empty or whitespace-only) reference is scored directly instead
    of being passed to jiwer, to prevent errors: 0.0 when the hypothesis is
    blank as well, otherwise 1.0.
    """
    reference_is_blank = not reference.strip()
    if not reference_is_blank:
        return wer(reference, hypothesis)
    hypothesis_is_blank = not hypothesis.strip()
    return 0.0 if hypothesis_is_blank else 1.0

def get_nouns(text):
    """
    Return the tokens of `text` whose NLTK POS tag starts with 'NN'
    (the noun tags), in their original order and casing.
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(str(text)))
    return [token for token, tag in tagged_tokens if tag.startswith('NN')]

def _score_pair(ref, hyp, target_nouns):
    """Score one prediction against its reference.

    Args:
        ref: Reference (target) sentence.
        hyp: Predicted sentence.
        target_nouns: Nouns extracted from `ref` (see get_nouns).

    Returns:
        Tuple ``(wer, approx_cer, bleu1, noun_recall)``.
    """
    # 1. WER (Word Error Rate)
    wer_score = calculate_wer(ref, hyp)

    # 2. Approximate CER: 1 - difflib similarity ratio. This is a proxy that
    # moves in the same direction as CER, not an edit-distance-based CER.
    cer_score = 1 - difflib.SequenceMatcher(None, ref, hyp).ratio()

    # Tokenize the hypothesis once; reused by BLEU and noun recall.
    hyp_tokens = nltk.word_tokenize(hyp)

    # 3. BLEU-1 (unigram weights only). Best effort: score 0 if it fails.
    try:
        bleu = sentence_bleu([nltk.word_tokenize(ref)], hyp_tokens, weights=(1, 0, 0, 0))
    except Exception:
        bleu = 0

    # 4. Noun recall: fraction of target nouns that appear (case-insensitively)
    # as whole tokens in the hypothesis.
    if not target_nouns:
        # No nouns in the target, so none could be missed.
        noun_recall = 1.0
    else:
        hyp_vocab = {token.lower() for token in hyp_tokens}
        matches = sum(1 for noun in target_nouns if noun.lower() in hyp_vocab)
        noun_recall = matches / len(target_nouns)

    return wer_score, cer_score, bleu, noun_recall


def evaluate_advanced_model(preds, labels):
    """
    Detailed per-sentence evaluation of the Advanced Model.

    Computes mean WER, approximate CER, BLEU-1 and noun recall over all
    (label, prediction) pairs, prints a summary table and returns the means.

    Args:
        preds: Predicted sentences.
        labels: Reference sentences, aligned index-by-index with `preds`.

    Returns:
        Dict with keys ``'wer'``, ``'cer'``, ``'bleu1'`` and ``'noun_recall'``
        holding the mean scores (NaN when there are no pairs).
    """
    # Create DataFrame for easier processing
    eval_df = pd.DataFrame({
        'target_text': labels,
        'predicted_text': preds
    })

    print("Extracting nouns from target labels for evaluation...")
    eval_df['target_nouns'] = eval_df['target_text'].apply(get_nouns)

    wer_scores, cer_scores, bleu_scores, noun_scores = [], [], [], []

    print("\n--- Detailed Evaluation for Advanced Model ---")

    rows = zip(eval_df['target_text'], eval_df['predicted_text'], eval_df['target_nouns'])
    for ref, hyp, target_nouns in tqdm(rows, total=len(eval_df)):
        wer_score, cer_score, bleu, noun_recall = _score_pair(str(ref), str(hyp), target_nouns)
        wer_scores.append(wer_score)
        cer_scores.append(cer_score)
        bleu_scores.append(bleu)
        noun_scores.append(noun_recall)

    results = {
        'wer': float(np.mean(wer_scores)),
        'cer': float(np.mean(cer_scores)),
        'bleu1': float(np.mean(bleu_scores)),
        'noun_recall': float(np.mean(noun_scores)),
    }

    # Print Aggregated Results
    print("\n" + "=" * 30)
    print("FINAL RESULTS (Advanced Model)")
    print("=" * 30)
    print(f"Mean WER:         {results['wer']:.4f} (Lower is better)")
    print(f"Mean CER:         {results['cer']:.4f} (Lower is better)")
    print(f"Mean BLEU-1:      {results['bleu1']:.4f} (Higher is better)")
    print(f"Mean Noun Recall: {results['noun_recall']:.4f} (Higher is better)")
    print("=" * 30)

    return results

# Assigned so the notebook does not echo the returned dict as cell output.
advanced_metrics = evaluate_advanced_model(decoded_preds, decoded_labels)