In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Load the dataset
file_path = 'path/to/your/training_data.csv'  # Update this with the actual path to your CSV file
data = pd.read_csv(file_path)

# Fail fast if the CSV lacks the columns read when building training examples;
# otherwise the KeyError would only surface later, in the middle of a training pass.
required_columns = ['feature', 'actual_val']
missing_columns = [name for name in required_columns if name not in data.columns]
if missing_columns:
    raise KeyError(f"{file_path} is missing required column(s): {missing_columns}")

# Split the data into training and validation sets
# (20% held out for validation; the fixed random_state keeps the split reproducible)
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# Initialize the tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Custom Dataset class
class RSTDataset(Dataset):
    """Map-style dataset that turns DataFrame rows into tokenized seq2seq examples.

    Each row supplies a source text and a target text. Both are tokenized to a
    fixed length, and padding positions in the target ids are replaced by -100
    so they are ignored by the loss.

    Args:
        data: pandas DataFrame holding one example per row.
        tokenizer: Callable tokenizer. It is called with ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` keywords, must
            return a mapping with ``'input_ids'`` and ``'attention_mask'``
            tensors, and must expose ``pad_token_id``.
        source_max_len: Length the source text is padded/truncated to.
        target_max_len: Length the target text is padded/truncated to.
        source_col: Name of the column holding the source text.
        target_col: Name of the column holding the target text.
    """

    def __init__(self, data, tokenizer, source_max_len=512, target_max_len=512,
                 source_col='feature', target_col='actual_val'):
        self.data = data
        self.tokenizer = tokenizer
        self.source_max_len = source_max_len
        self.target_max_len = target_max_len
        self.source_col = source_col
        self.target_col = target_col

    def __len__(self):
        """Return the number of rows (examples) in the underlying DataFrame."""
        return len(self.data)

    def _encode(self, text, max_len):
        """Tokenize ``text``, padding/truncating it to exactly ``max_len`` tokens."""
        return self.tokenizer(
            text,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        """Return the ``idx``-th example as a dict of 1-D tensors.

        Keys: ``'input_ids'`` and ``'attention_mask'`` for the source text, and
        ``'labels'`` for the target text with padding ids replaced by -100.
        """
        # Positional lookup, done once, so the DataFrame index need not be 0..n-1.
        row = self.data.iloc[idx]
        source_text = str(row[self.source_col])
        target_text = str(row[self.target_col])

        # Tokenize the source and target texts
        source_encoding = self._encode(source_text, self.source_max_len)
        target_encoding = self._encode(target_text, self.target_max_len)

        # Work on a copy so the tokenizer's own output is not modified in place.
        labels = target_encoding['input_ids'].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100  # Ignore padding tokens in the loss

        return {
            'input_ids': source_encoding['input_ids'].flatten(),
            'attention_mask': source_encoding['attention_mask'].flatten(),
            'labels': labels.flatten()
        }

# Create DataLoader for training and validation
train_dataset = RSTDataset(train_data, tokenizer)
val_dataset = RSTDataset(val_data, tokenizer)

# Only the training loader shuffles; validation keeps the default sequential order.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)

# Initialize the T5 model and place it on the GPU when one is available
model = T5ForConditionalGeneration.from_pretrained('t5-small')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Optimizer
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are pinned to the defaults of the transformers implementation
# (1e-6 and 0.0) so the update rule stays the same as before the switch.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Training function
def train_epoch(model, data_loader, optimizer, device, scheduler=None):
    """Run one training pass over ``data_loader`` and return the mean batch loss.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and ``labels``
            keywords; its output must expose a ``loss`` attribute.
        data_loader: Iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to before the forward pass.
        scheduler: Optional learning-rate scheduler, stepped once per batch
            after the optimizer.

    Returns:
        The arithmetic mean of the per-batch losses, or NaN when
        ``data_loader`` yields no batches.
    """
    model = model.train()
    losses = []
    for batch in tqdm(data_loader, desc="Training"):
        inputs = {
            key: batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        }

        outputs = model(**inputs)

        loss = outputs.loss
        losses.append(loss.item())

        # Clear gradients left over from the previous batch before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compare against None explicitly rather than relying on object truthiness.
        if scheduler is not None:
            scheduler.step()

    # An empty loader has no mean; report NaN instead of raising ZeroDivisionError.
    if not losses:
        return float('nan')
    return sum(losses) / len(losses)

# Evaluation function
def eval_model(model, data_loader, device):
    """Compute the mean batch loss over ``data_loader`` without updating the model.

    The model is switched to evaluation mode and left in that mode on return.

    Args:
        model: Model called with ``input_ids``, ``attention_mask`` and ``labels``
            keywords; its output must expose a ``loss`` attribute.
        data_loader: Iterable of batches, each a mapping with the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The arithmetic mean of the per-batch losses, or NaN when
        ``data_loader`` yields no batches.
    """
    model = model.eval()
    losses = []
    # Gradients are not needed for evaluation; disabling them saves memory and time.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            inputs = {
                key: batch[key].to(device)
                for key in ('input_ids', 'attention_mask', 'labels')
            }

            outputs = model(**inputs)
            losses.append(outputs.loss.item())

    # An empty loader has no mean; report NaN instead of raising ZeroDivisionError.
    if not losses:
        return float('nan')
    return sum(losses) / len(losses)

# Training loop: each epoch is one full training pass followed by one validation pass
epochs = 3
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    train_loss = train_epoch(
        model=model,
        data_loader=train_loader,
        optimizer=optimizer,
        device=device,
    )
    print(f"Train loss: {train_loss}")

    val_loss = eval_model(model=model, data_loader=val_loader, device=device)
    print(f"Validation loss: {val_loss}")

# Save the model, then the tokenizer, into the same directory
for artifact in (model, tokenizer):
    artifact.save_pretrained('rst_corrector_model')


In [1]:
input_text = '   ==============\nreStructuredText\n==============\n\n**reStructuredText (reST)** is a lightweight markup language. Here’s a demonstration of various reST features\n\nSections    and  Subsections\n-------------------------\n\n    You can create sections and subsections using different underlines:\n   \n   Introduction\n~~~~~~~~~~~~\n\nThis is the introduction section.\n\nLists\n-----\n\nBullet List Item 1\nBullet List Item 2\n    - Nested Bullet List Item\n    -    Another   Nested  Item\n\n1. Numbered List Item 1\n2. Numbered List Item 2\n1. Nested Numbered Item 1\n    2.   Nested   Numbered  Item   2\n\n Links and References\n--------------------\n\n- `reStructuredText Documentation <https://docutils.sourceforge.io/rst.html>`_\nInternal reference to the `Lists`_ section.\n    \nInline  Markup\n-------------\n  \nThis is **bold text**, *italic text*, and ``inline literal text``. Here’s a `link <https://www.example.com>`_..\n\nImages\n------\n\n..  image::    https://via.placeholder.com/150\n    :alt:    Placeholder   Image\n:width: 100px\n  \nBlock Quotes\n\n  \nHere’s a block quote\n\n"To be, or not to be, that is the question." -- William Shakespeare\n\nLiteral Blocks\n --------------\n\nYou can include literal blocks by indenting:\n\n\n\n   def hello_world()\n         print("Hello,    World!")\n\nTables\n\n\nHere’s a simple table\n   \n+-----------------+-----------------+\n| Header 1        | Header 2        |\n+=================+=================+\n| Row 1, Column 1 | Row 1, Column 2 |\n+-----------------+-----------------+\n| Row 2, Column 1 | Row 2, Column 2 |\n+-----------------+-----------------+\n\nDirectives\n-\n\n.. note::\n\n    This is a note directive.\n   \n.. warning::\n\n   This is a warning directive.\n\n.. code-block: python\n\n   def hello_world()\n        print("Hello,    World!")\n\nAdmonitions\n-----------\n \n.. 
admonition:: hint\n\n   reStructuredText is powerful yet simple!\n \nFootnotes\n-----\n\nYou can add footnotes like this: [#]_..\n\n.. [#] this is the footnote.\n\nDefinition Lists\n----------------\n\n  Python\n  A high-level programming language.\n\nreStructuredText\nA lightweight markup language.\n\nSphinx Domains\n  --------------\n \nSphinx domains are used for documenting code\n    \n.. py:function:: my_function(arg1, arg2)\n\n       This is a Python function.\n\n.. c:function:: int my_c_function(int arg1, int arg2)\n\n   This is a C function.\n\nConclusion\n----------\n\nThis reStructuredText document showcases various features you can use to create structured and well-formatted documentation.\n\n   \n\n    \n'

In [None]:
# Tokenize the input text (truncated to 512 tokens, the same cap used for training inputs)
input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True)

# Move the input onto the same device as the model; generate() fails if they differ
input_ids = input_ids.to(device)

# Generate the corrected output
# eval() disables dropout and no_grad() skips gradient tracking, neither of which
# is wanted at inference time.
model.eval()
with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=512, num_beams=4, early_stopping=True)

# Decode the output to get the corrected text
corrected_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Print the corrected text
print(corrected_text)