In [1]:
# from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBART", src_lang="te_IN", use_fast=False)

# vocab = tokenizer.get_vocab()
# telugu_tokens = [tok for tok in vocab.keys() if any('\u0C00' <= ch <= '\u0C7F' for ch in tok)]

# print(telugu_tokens)

The IndicBART tokenizer is based on SentencePiece, but instead of standard subword units like BPE or WordPiece, it's trained at the Unicode character level, where each token is often a:
- Standalone consonant (క, గ, త, etc.)
- Vowel sign or diacritic (ా, ి, ీ, etc.)
- Word boundary marker (▁ for whitespace)

This is intentional for Indic scripts because:

- Indic languages are highly agglutinative, and subword segmentation can be noisy.
- It's better to model individual aksharas (syllables) or character+diacritic units instead of full words or arbitrary subwords.

In [6]:
# !pip install -r requirements.txt

In [25]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import os

from transformers import MT5Tokenizer, MT5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader

In [None]:
model_name = "google/mt5-small"  # or "t5-base", "t5-large"
tokenizer = MT5Tokenizer.from_pretrained(model_name) # loads the tokenizer from hugging face

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [33]:
SamantarData_DirCS = pd.read_csv(r"./Datasets/SamantarDatasetWithDirectConsonantSubstitutions.csv") # pure consonant replacement
SamantarData_DiaS = pd.read_csv(r"./Datasets/SamantarDatasetWithDiacriticSubstitutions.csv") # Diacritic substitution errors
WebScrapedData_DirCS = pd.read_csv(r"./Datasets/SamantarDatasetWithDirectConsonantSubstitutions.csv") # pure consonant replacement
WebScrapedData_DiaS = pd.read_csv(r"./Datasets/SamantarDatasetWithDiacriticSubstitutions.csv") # Diacritic substitution errors
print(SamantarData_DirCS.columns)

Index(['OriginalText', 'AverageErrorsPerWord_0.15%', '15%_ErrorInducedText',
       'AverageErrorsPerWord_0.25%', '25%_ErrorInducedText',
       'AverageErrorsPerWord_0.35%', '35%_ErrorInducedText',
       'AverageErrorsPerWord_0.5%', '50%_ErrorInducedText'],
      dtype='object')


In [None]:
error_percentages = ['15', '25', '35', '50']

SamantarData_DirCS_processed = pd.DataFrame()
SamantarData_DiaS_processed = pd.DataFrame()
WebScrapedData_DirCS_processed = pd.DataFrame()
WebScrapedData_DiaS_processed = pd.DataFrame()

# Process DirCS
for perc in error_percentages:
    df_Samantar = SamantarData_DirCS[['OriginalText', f'{perc}%_ErrorInducedText']].rename(
        columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'}
    )
    SamantarData_DirCS_processed[f'SamantarData_DirCS_{perc}p'] = df_Samantar.to_dict(orient='records')
    
    df_webscraped = WebScrapedData_DirCS[['OriginalText', f'{perc}%_ErrorInducedText']].rename(
        columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'}
    )
    WebScrapedData_DirCS_processed[f'WebScrapedData_DirCS_{perc}p'] = df_webscraped.to_dict(orient='records')

# Process DiaS
for perc in error_percentages:
    df_Samantar = SamantarData_DiaS[['OriginalText', f'{perc}%_ErrorInducedText']].rename(
        columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'}
    )
    SamantarData_DiaS_processed[f'SamantarData_DiaS_{perc}p'] = df_Samantar.to_dict(orient='records')
    
    df_webscraped = WebScrapedData_DiaS[['OriginalText', f'{perc}%_ErrorInducedText']].rename(
        columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'}
    )
    WebScrapedData_DiaS_processed[f'WebScrapedData_DiaS_{perc}p'] = df_webscraped.to_dict(orient='records')

In [53]:
def preprocess(text):
    input_enc = tokenizer(text["input"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    target_enc = tokenizer(text["target"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_enc["labels"] = target_enc["input_ids"]
    return input_enc

In [None]:
datasets = {
    "SamantarData_DirCS_15p": SamantarData_DirCS_processed['SamantarData_DirCS_15p'],
    "SamantarData_DirCS_25p": SamantarData_DirCS_processed['SamantarData_DirCS_25p'],
    "SamantarData_DirCS_35p": SamantarData_DirCS_processed['SamantarData_DirCS_35p'],
    "SamantarData_DirCS_50p": SamantarData_DirCS_processed['SamantarData_DirCS_50p'],
    
    "SamantarData_DiaS_15p": SamantarData_DiaS_processed['SamantarData_DiaS_15p'],
    "SamantarData_DiaS_25p": SamantarData_DiaS_processed['SamantarData_DiaS_25p'],
    "SamantarData_DiaS_35p": SamantarData_DiaS_processed['SamantarData_DiaS_35p'],
    "SamantarData_DiaS_50p": SamantarData_DiaS_processed['SamantarData_DiaS_50p'],
    
    "WebScrapedData_DirCS_15p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_15p'],
    "WebScrapedData_DirCS_25p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_25p'],
    "WebScrapedData_DirCS_35p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_35p'],
    "WebScrapedData_DirCS_50p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_50p'],
    
    "WebScrapedData_DiaS_15p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_15p'],
    "WebScrapedData_DiaS_25p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_25p'],
    "WebScrapedData_DiaS_35p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_35p'],
    "WebScrapedData_DiaS_50p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_50p'],
}

tokenized_datasets = {}

for name, dataset in datasets.items():
    tokenized_datasets[f"tokenized_{name}"] = [preprocess(item) for item in dataset]

In [55]:
class T5Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        item = {k: v.squeeze() for k, v in self.data[idx].items()}
        return item

    def __len__(self):
        return len(self.data)

In [None]:
tokenized_datasets = {
    
    # Samantar Data
    "SamantarData_DirCS_15p": [preprocess(item) for item in SamantarData_DirCS_processed['SamantarData_DirCS_15p']],
    "SamantarData_DirCS_25p": [preprocess(item) for item in SamantarData_DirCS_processed['SamantarData_DirCS_25p']],
    "SamantarData_DirCS_35p": [preprocess(item) for item in SamantarData_DirCS_processed['SamantarData_DirCS_35p']],
    "SamantarData_DirCS_50p": [preprocess(item) for item in SamantarData_DirCS_processed['SamantarData_DirCS_50p']],
    
    "SamantarData_DiaS_15p": [preprocess(item) for item in SamantarData_DiaS_processed['SamantarData_DiaS_15p']],
    "SamantarData_DiaS_25p": [preprocess(item) for item in SamantarData_DiaS_processed['SamantarData_DiaS_25p']],
    "SamantarData_DiaS_35p": [preprocess(item) for item in SamantarData_DiaS_processed['SamantarData_DiaS_35p']],
    "SamantarData_DiaS_50p": [preprocess(item) for item in SamantarData_DiaS_processed['SamantarData_DiaS_50p']],
    
    # WebScraped Data
    "WebScrapedData_DirCS_15p": [preprocess(item) for item in WebScrapedData_DirCS_processed['WebScrapedData_DirCS_15p']],
    "WebScrapedData_DirCS_25p": [preprocess(item) for item in WebScrapedData_DirCS_processed['WebScrapedData_DirCS_25p']],
    "WebScrapedData_DirCS_35p": [preprocess(item) for item in WebScrapedData_DirCS_processed['WebScrapedData_DirCS_35p']],
    "WebScrapedData_DirCS_50p": [preprocess(item) for item in WebScrapedData_DirCS_processed['WebScrapedData_DirCS_50p']],
    
    "WebScrapedData_DiaS_15p": [preprocess(item) for item in WebScrapedData_DiaS_processed['WebScrapedData_DiaS_15p']],
    "WebScrapedData_DiaS_25p": [preprocess(item) for item in WebScrapedData_DiaS_processed['WebScrapedData_DiaS_25p']],
    "WebScrapedData_DiaS_35p": [preprocess(item) for item in WebScrapedData_DiaS_processed['WebScrapedData_DiaS_35p']],
    "WebScrapedData_DiaS_50p": [preprocess(item) for item in WebScrapedData_DiaS_processed['WebScrapedData_DiaS_50p']],
}

In [58]:
dataset_loaders = {}

for name, tokenized_data in tokenized_datasets.items():
    dataset = T5Dataset(tokenized_data)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    dataset_loaders[name] = loader

In [None]:
def train(dataset_loader, num_epochs, model, model_name="model"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    Model = model.to(device)

    optimizer = torch.optim.AdamW(Model.parameters(), lr=1e-4)

    losses = []

    Model.train()
    for epoch in range(num_epochs):
        for batch in dataset_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            losses.append(loss.item())

            print(f"Loss: {loss.item():.4f}")

    os.makedirs("models", exist_ok=True)
    os.makedirs("loss_plots", exist_ok=True)

    plt.figure()
    plt.plot(losses)
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title(f'Training Loss Curve: {model_name}')
    plt.grid(True)
    plt.savefig(f"loss_plots/{model_name}_loss_curve.png")   # Save figure
    plt.close()

    return model

In [None]:
dataloaders = {
    f"model_{name}": loader
    for name, loader in dataset_loaders.items()
}

trained_models = {}

model = MT5ForConditionalGeneration.from_pretrained(model_name) # load the weights of the model from hugging face

for model_name, loader in dataloaders.items():
    print(f"Training {model_name}...")
    trained_model = train(loader, num_epochs=10, model, model_name=model_name)
    trained_models[model_name] = model 
    torch.save(model.state_dict(), f"{model_name}.pt")
    print(f"Saved {model_name}.pt successfully.\n")

OSError: model_SamantarData_DirCS_15p is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`