# In this notebook I showcase a simple fine-tuning example of teaching a t5-small model to output the 'b' version of an english word. A 'b' version is the original word, where the letter 'b' appears after every vowel in the original word.

## Install packages

In [None]:
pip install torch transformers datasets accelerate SentencePiece

## Import relevant libraries

In [2]:
import torch

## Define a function that creates a word in the "b" language, and then generate a dataset accordingly

In [3]:
def insert_b_after_vowels(word):
    vowels = "aeiou"
    modified_word = ""
    for letter in word:
        if letter.lower() in vowels:  # Case-insensitive check
            modified_word += letter + 'b' + letter
        else:
            modified_word += letter
    return modified_word

def create_dataset(words):
    dataset = []
    for word in words:
        modified_word = insert_b_after_vowels(word)
        dataset.append({"before": word, "after": modified_word})
    return dataset

## Use pretrained t5-small model and tokenizer

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
tokenizer.pad_token = tokenizer.eos_token

## Create a function that generates n random words in english

In [5]:
import requests
def get_random_words(n):
    try:
        response = requests.get(f"https://random-word-api.herokuapp.com/word?number={n}")
        if response.status_code == 200:
            return response.json()
        else:
            print("Failed to fetch words")
            return []
    except Exception as e:
        print(f"An error occurred: {e}")
        return []


## create a dataset of n_words pairs of english words and their corresponding 'b' word

In [6]:
# words = ['dog','dog'] for overfit check
n_words = 200000
words = get_random_words(n_words)
dataset = create_dataset(words)

## Encode the before and after tokens

In [7]:
# Adjust the function if needed to handle the entire dataset at once
def encode_examples(dataset, max_length=128, padding="max_length", truncation=True):
    input_encodings = tokenizer([ex["before"] for ex in dataset], max_length=max_length, truncation=truncation, padding=padding, return_tensors="pt")
    labels = tokenizer([ex["after"] for ex in dataset], max_length=max_length, truncation=truncation, padding=padding, return_tensors="pt").input_ids
    # Replace -100 in the labels as padding token
    labels[labels == tokenizer.pad_token_id] = -100
    return {
        "input_ids": input_encodings.input_ids,
        "attention_mask": input_encodings.attention_mask,
        "labels": labels
    }

# Now, call the function with the entire dataset
encoded_dataset = encode_examples(dataset)

## Create Dataset objects for training and evaluation

In [8]:
from datasets import Dataset

total_dataset = Dataset.from_dict(encoded_dataset)
split_dataset = total_dataset.train_test_split(test_size=0.1)
train_dataset = split_dataset['train']
eval_dataset = split_dataset['test']


## Mount Google Drive for local saving of the checkpoints

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## Define training arguments and trainer module

In [17]:
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

training_args = TrainingArguments(
    output_dir= '/content/drive/MyDrive/model_weights',  # Directory for saving training outputs
    num_train_epochs=10,  # Number of training epochs
    per_device_train_batch_size=128,  # Batch size per device during training
    logging_dir='./logs',  # Directory for training logs
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    save_strategy="epoch",  # Save model checkpoints at the end of each epoch
    gradient_accumulation_steps=16,  # Number of gradient accumulation steps
    # save_total_limit=4,  # Limit the total number of checkpoints
    learning_rate=5e-4,  # Learning rate for training
    weight_decay=0.01,  # Weight decay for regularization
    gradient_checkpointing=False,  # Disable gradient checkpointing
    fp16=True,  # Use mixed precision training
    torch_compile=True,  # Use torch compilation for faster training
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,  # Use dynamic padding
)

###  I needed this to walkarround some A100 GPU warnings

In [12]:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

## Fine-tune the model over our data

In [None]:
trainer.train()

## Inference function to use the fine-tuned model over a given word

In [14]:
import torch

def generate_modified_word(model, tokenizer, word, max_length=128):
    """
    Generate a modified word using the fine-tuned T5 model, moving computations to GPU if available.

    Args:
    - model: The fine-tuned T5 model.
    - tokenizer: The T5 tokenizer.
    - word: The input word to modify.
    - max_length: The maximum length of the output sequence.

    Returns:
    - The modified word as predicted by the model.
    """
    # Detect if GPU is available and move model to the appropriate device (GPU or CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Prepare the input text
    input_text =  word
    # Encode the input text, ensuring the tensors are on the same device as the model
    input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True).to(device)

    # Generate output tokens, ensuring the operation occurs on the same device
    output_tokens = model.generate(input_ids, max_length=max_length)

    # Decode the output tokens
    modified_word = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    return modified_word

## Overfit checking


In [19]:
word = dataset[0]["before"]  # The word you want to modify
gt_word = dataset[0]["after"]  # The ground truth word
modified_word = generate_modified_word(model, tokenizer, word)
print(f"Original: {word}, Modified: {modified_word}, GT: {gt_word}")

Original: spired, Modified: <pad> spibirebed, GT: spibirebed


## Test new word

In [22]:
word = "cat"  # The word you want to modify
gt_word = insert_b_after_vowels(word) # The ground truth word
modified_word = generate_modified_word(model, tokenizer, word)
print(f"Original: {word}, Modified: {modified_word}, GT: {gt_word}")

Original: cat, Modified: <pad> cabat, GT: cabat
