In [65]:
import os

import pandas as pd
import torch
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.trainer_utils import get_last_checkpoint

In [66]:
# Load the pre-trained tokenizer and seq2seq model from the same checkpoint
# (a Hugging Face Hub repo id, or a local directory of that name).
# A single constant keeps the tokenizer and model from drifting apart.
MODEL_NAME = "Himanshu9192/Spelling_Checker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


In [67]:
# Load the raw big.txt corpus as a list of lines (newlines kept; stripped later).
# The encoding is explicit: open()'s default is platform/locale dependent
# (e.g. cp1252 on Windows), which would make the notebook non-portable.
with open("dataset/big.txt", "r", encoding="utf-8") as f:
    big_data = f.readlines()

In [68]:
# Trim surrounding whitespace from every line and drop the ones left empty.
big_data = [text for text in (raw.strip() for raw in big_data) if text]

# NOTE(review): input and target are the same string, so the model is only
# trained to reproduce its input unchanged and never sees a misspelling to
# correct. Consider generating corrupted inputs with clean targets.
train_dataset = pd.DataFrame({"input": big_data, "target": big_data})

In [76]:
max_length: int = 128  # every input/target sequence is padded or truncated to this many tokens

In [85]:
def preprocess_dataset(row, *, tok=None, max_len=None):
    """Tokenize one (input, target) row into fixed-length tensors for seq2seq training.

    Args:
        row: Record indexable by ``"input"`` and ``"target"`` (a pandas row or a
            plain dict), each holding a string.
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.
        max_len: Length every sequence is padded/truncated to. Defaults to the
            notebook-level ``max_length``.

    Returns:
        dict with 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``,
        each of length ``max_len``. Padding positions in ``labels`` are -100 so
        they are excluded from the loss.
    """
    if tok is None:
        tok = tokenizer
    if max_len is None:
        max_len = max_length

    encode_kwargs = dict(
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    source = tok(row["input"], **encode_kwargs)
    target = tok(row["target"], **encode_kwargs)

    # return_tensors="pt" adds a leading batch dimension of size 1; drop it so
    # examples can be stacked into (batch, max_len) later.
    input_ids = source["input_ids"].squeeze(0)
    attention_mask = source["attention_mask"].squeeze(0)

    # Without this, padded label positions keep the pad-token id and the model is
    # trained to predict padding. Transformers seq2seq models ignore label
    # positions equal to -100 in the loss. The target's attention mask marks
    # exactly which positions are padding.
    target_mask = target["attention_mask"].squeeze(0)
    labels = target["input_ids"].squeeze(0).masked_fill(target_mask == 0, -100)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

In [86]:
# Tokenize every (input, target) pair up front; each record becomes a dict of tensors.
preprocessed_data = [
    preprocess_dataset(record)
    for record in train_dataset.to_dict(orient="records")
]

In [87]:
# Wrap the list of preprocessed examples in a map-style PyTorch dataset.
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of ready-made examples.

    Items are returned exactly as stored; no transformation happens on access.
    """

    def __init__(self, data):
        # Holds a reference to the caller's sequence; nothing is copied.
        self.data = data

    def __len__(self):
        """Return how many examples are stored."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored example at ``index`` unchanged."""
        return self.data[index]

In [88]:
train_dataset_preprocessed = CustomDataset(data=preprocessed_data)  # indexable, sized view over the tokenized examples

In [89]:
# Training configuration: checkpoints are written to ./checkpoints, logs to ./logs.
training_args = TrainingArguments(
    output_dir="./checkpoints",
    logging_dir="./logs",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    # Trainer checkpoints periodically during training (every 500 steps by
    # default per the transformers docs; confirm for the installed version).
    # At ~38.8k steps that would leave dozens of full model copies on disk,
    # so keep only the two most recent.
    save_total_limit=2,
    # Explicit seed so re-runs are reproducible (42 is the documented default;
    # stated here so it is visible and easy to change).
    seed=42,
)


In [90]:
# Combine the training configuration, the model and the tokenized dataset into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset_preprocessed,
)

In [91]:
# Train. If output_dir already holds a checkpoint (e.g. from an interrupted run),
# resume from the newest one; otherwise start from scratch.
# NOTE(review): a checkpoint left by a run with different settings is picked up
# too, so clear output_dir when changing the setup.
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)

  0%|          | 0/38814 [00:00<?, ?it/s]

KeyboardInterrupt: 