In [3]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

# Load the pre-trained T5-small model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Define the task-specific head
task_prefix = "grammar and spelling correction: "

# Define the training data
training_data = [
    ("Teh quik brown fox jumpd over teh lazy dog.", "The quick brown fox jumped over the lazy dog."),
    ("I hav two appels and three orngez.", "I have two apples and three oranges."),
    ("Thay went too the park.", "They went to the park."),
    ("Its a beautifull day.", "It's a beautiful day."),
    ("I cant wait for the weeknd.", "I can't wait for the weekend."),
    ("Thier house is very nice.", "Their house is very nice."),
    ("Definately going to the party.", "Definitely going to the party."),
    ("I dont no what to do.", "I don't know what to do."),
    ("Whos going to the movies?", "Who's going to the movies?"),
    ("I dint do my homework.", "I didn't do my homework."),
]

# Preprocess the data
def preprocess_data(input_texts, target_texts):
    inputs = [task_prefix + text for text in input_texts]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length", return_tensors="pt")
    labels = tokenizer(target_texts, max_length=512, truncation=True, padding="max_length", return_tensors="pt")["input_ids"]
    return model_inputs, labels

# Prepare the training data
input_texts, target_texts = zip(*training_data)
train_inputs, train_labels = preprocess_data(input_texts, target_texts)

# Fine-tune the model
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 10

for epoch in range(num_epochs):
    outputs = model(input_ids=train_inputs["input_ids"], attention_mask=train_inputs["attention_mask"], labels=train_labels)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")

# Evaluate the model
print("-----Starting model eval mode - ready to make predictions")
model.eval()
print("-----Finished model eval")


# Inference method
def correct_text(input_text):
    input_ids = tokenizer.encode(task_prefix + input_text, return_tensors="pt")
    output_ids = model.generate(input_ids, max_length=512)[0]
    corrected_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return corrected_text


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Epoch 1/10, Loss: 10.7148
Epoch 2/10, Loss: 8.3474
Epoch 3/10, Loss: 6.0314
Epoch 4/10, Loss: 7.2457
Epoch 5/10, Loss: 5.4176
Epoch 6/10, Loss: 3.2987
Epoch 7/10, Loss: 2.8367
Epoch 8/10, Loss: 1.9369


In [2]:
#run inference
input_text = "I draaanks waatttere"
input_ids = tokenizer.encode(task_prefix + input_text, return_tensors="pt")
print("------Input IDs",input_ids)
output_ids = model.generate(input_ids, max_length=512)[0]
print("------Output IDs",output_ids)

corrected_text = tokenizer.decode(output_ids, skip_special_tokens=True)
print(f"Input: {input_text}")
print(f"Corrected: {corrected_text}")

------Input IDs tensor([[19519,    11, 19590, 11698,    10,    27,     3,  3515,     9,  5979,
             7,  8036,   144,    17,   449,    15,     1]])
------Output IDs tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0,