In [1]:
# from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained("ai4bharat/IndicBART", src_lang="te_IN", use_fast=False)

# vocab = tokenizer.get_vocab()
# telugu_tokens = [tok for tok in vocab.keys() if any('\u0C00' <= ch <= '\u0C7F' for ch in tok)]

# print(telugu_tokens)

The IndicBART tokenizer is based on SentencePiece, but instead of standard subword units like BPE or WordPiece, it's trained at the Unicode character level, where each token is often a:
- Standalone consonant (క, గ, త, etc.)
- Vowel sign or diacritic (ా, ి, ీ, etc.)
- Word boundary marker (▁ for whitespace)

This is intentional for Indic scripts because:

- Indic languages are highly agglutinative, and subword segmentation can be noisy.
- It's better to model individual aksharas (syllables) or character+diacritic units instead of full words or arbitrary subwords.

In [2]:
# !pip install -r requirements.txt

In [3]:
import torch

In [4]:
import pandas as pd
import torch
import matplotlib.pyplot as plt
import os

from transformers import MT5Tokenizer, MT5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# !pip install sentencepiece

In [6]:
model_name = "google/mt5-small"  # or any other model name you want to use
tokenizer = MT5Tokenizer.from_pretrained(model_name) # loads the tokenizer from hugging face

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [7]:
SamantarData_DirCS = pd.read_csv(r"./Datasets/SamantarDatasetWithDirectConsonantSubstitutions.csv") # pure consonant replacement
SamantarData_DiaS = pd.read_csv(r"./Datasets/SamantarDatasetWithStackedSubstitutions.csv") # Diacritic substitution errors
WebScrapedData_DirCS = pd.read_csv(r"./Datasets/SamantarDatasetWithDirectConsonantSubstitutions.csv") # pure consonant replacement
WebScrapedData_DiaS = pd.read_csv(r"./Datasets/SamantarDatasetWithStackedSubstitutions.csv") # Diacritic substitution errors
print(SamantarData_DiaS.columns)

Index(['OriginalText', '15%_ErrorInducedText', 'Percentage_ErrorInduced_15%',
       'AverageErrorPerWord_15%', '25%_ErrorInducedText',
       'Percentage_ErrorInduced_25%', 'AverageErrorPerWord_25%',
       '35%_ErrorInducedText', 'Percentage_ErrorInduced_35%',
       'AverageErrorPerWord_35%', '50%_ErrorInducedText',
       'Percentage_ErrorInduced_50%', 'AverageErrorPerWord_50%'],
      dtype='object')


In [8]:
error_percentages = ['15', '25', '35', '50']

SamantarData_DirCS_processed = {}
SamantarData_DiaS_processed = {}
WebScrapedData_DirCS_processed = {}
WebScrapedData_DiaS_processed = {}

# Filter out the row sentences with more than 2 errors in the average error per word
# and create a new DataFrame with the required columns for each error percentage
for perc in error_percentages:

    avg_col = f'AverageErrorPerWord_{perc}%'

    df_Samantar = SamantarData_DirCS[(SamantarData_DirCS[avg_col] > 2)][['OriginalText', f'{perc}%_ErrorInducedText']].rename(columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'})
    
    SamantarData_DirCS_processed[f'SamantarData_DirCS_{perc}p'] = df_Samantar.to_dict(orient='records')
    
    df_webscraped = WebScrapedData_DirCS[(WebScrapedData_DirCS[avg_col] > 2)][['OriginalText', f'{perc}%_ErrorInducedText']].rename(columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'})

    WebScrapedData_DirCS_processed[f'WebScrapedData_DirCS_{perc}p'] = df_webscraped.to_dict(orient='records')

# Process DiaS
for perc in error_percentages:
    
    avg_col = f'AverageErrorPerWord_{perc}%'

    df_Samantar = SamantarData_DiaS[(SamantarData_DiaS[avg_col] > 2)][['OriginalText', f'{perc}%_ErrorInducedText']].rename(columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'})

    SamantarData_DiaS_processed[f'SamantarData_DiaS_{perc}p'] = df_Samantar.to_dict(orient='records')
    
    df_webscraped = WebScrapedData_DiaS[(WebScrapedData_DiaS[avg_col] > 2)][['OriginalText', f'{perc}%_ErrorInducedText']].rename(columns={'OriginalText': 'input', f'{perc}%_ErrorInducedText': 'target'})

    WebScrapedData_DiaS_processed[f'WebScrapedData_DiaS_{perc}p'] = df_webscraped.to_dict(orient='records')

In [9]:
def preprocess(text):
    input_enc = tokenizer(text["input"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    target_enc = tokenizer(text["target"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
    input_enc["labels"] = target_enc["input_ids"]
    return input_enc

In [10]:
datasets = {
    "SamantarData_DirCS_15p": SamantarData_DirCS_processed['SamantarData_DirCS_15p'],
    "SamantarData_DirCS_25p": SamantarData_DirCS_processed['SamantarData_DirCS_25p'],
    "SamantarData_DirCS_35p": SamantarData_DirCS_processed['SamantarData_DirCS_35p'],
    "SamantarData_DirCS_50p": SamantarData_DirCS_processed['SamantarData_DirCS_50p'],
    
    "SamantarData_DiaS_15p": SamantarData_DiaS_processed['SamantarData_DiaS_15p'],
    "SamantarData_DiaS_25p": SamantarData_DiaS_processed['SamantarData_DiaS_25p'],
    "SamantarData_DiaS_35p": SamantarData_DiaS_processed['SamantarData_DiaS_35p'],
    "SamantarData_DiaS_50p": SamantarData_DiaS_processed['SamantarData_DiaS_50p'],
    
    "WebScrapedData_DirCS_15p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_15p'],
    "WebScrapedData_DirCS_25p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_25p'],
    "WebScrapedData_DirCS_35p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_35p'],
    "WebScrapedData_DirCS_50p": WebScrapedData_DirCS_processed['WebScrapedData_DirCS_50p'],
    
    "WebScrapedData_DiaS_15p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_15p'],
    "WebScrapedData_DiaS_25p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_25p'],
    "WebScrapedData_DiaS_35p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_35p'],
    "WebScrapedData_DiaS_50p": WebScrapedData_DiaS_processed['WebScrapedData_DiaS_50p'],
}

tokenized_datasets = {}

for name, dataset in datasets.items():
    tokenized_datasets[f"tokenized_{name}"] = [preprocess(item) for item in dataset]

In [11]:
print("Shape of tokenized embedding : ", tokenized_datasets["tokenized_SamantarData_DirCS_15p"][0]["attention_mask"].shape)
print("Shape of tokenized embedding : ", tokenized_datasets["tokenized_SamantarData_DiaS_15p"][0]["attention_mask"].shape)
print("Shape of tokenized embedding : ", tokenized_datasets["tokenized_WebScrapedData_DirCS_15p"][0]["attention_mask"].shape)

Shape of tokenized embedding :  torch.Size([1, 512])
Shape of tokenized embedding :  torch.Size([1, 512])
Shape of tokenized embedding :  torch.Size([1, 512])


In [12]:
class T5Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        item = {k: v.squeeze() for k, v in self.data[idx].items()}
        return item

    def __len__(self):
        return len(self.data)

In [13]:
dataset_loaders = {}

for name, tokenized_data in tokenized_datasets.items():
    dataset = T5Dataset(tokenized_data)
    loader = DataLoader(dataset, batch_size=5, shuffle=True)
    dataset_loaders[name] = loader

In [14]:
def train(dataset_loader, num_epochs, model, model_name="model", writer=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    Model = model.to(device)

    optimizer = torch.optim.AdamW(Model.parameters(), lr=1e-4)

    losses = []

    Model.train()
    global_step = 1
    for epoch in range(num_epochs):
        print(f"Epoch: {epoch} started")
        epoch_loss = 0
        for batch in dataset_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

            writer.add_scalar("Loss/train_batch", loss.item(), global_step)

            losses.append(loss.item())
            global_step += 1

        avg_epoch_loss = epoch_loss / len(dataset_loader)
        writer.add_scalar("Loss/train_epoch", avg_epoch_loss, epoch)

        if(epoch % 5 == 0):
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    os.makedirs("models", exist_ok=True)
    os.makedirs("loss_plots", exist_ok=True)

    plt.figure()
    plt.plot(losses)
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.title(f'Training Loss Curve: {model_name}')
    plt.grid(True)
    plt.savefig(f"loss_plots/{model_name}_loss_curve.png")   # Save figure
    plt.close()

    return model

In [19]:
del model

In [20]:
from torch.utils.tensorboard import SummaryWriter

train_dataloader = dataset_loaders["tokenized_WebScrapedData_DiaS_25p"]
run_name = "WebScrapedData_DiaS_25p"
model = MT5ForConditionalGeneration.from_pretrained(model_name)
writer = SummaryWriter(log_dir=f"runs/{run_name}")

trained_model = train(train_dataloader, num_epochs=20, model=model, model_name=run_name, writer=writer)
writer.close()

# Save the model to a directory
trained_model.save_pretrained(f"models/{run_name}")

Epoch: 0 started
Epoch 0, Loss: 7.0842
Epoch: 1 started
Epoch: 2 started
Epoch: 3 started
Epoch: 4 started
Epoch: 5 started
Epoch 5, Loss: 0.5608
Epoch: 6 started
Epoch: 7 started
Epoch: 8 started
Epoch: 9 started
Epoch: 10 started
Epoch 10, Loss: 0.4776
Epoch: 11 started
Epoch: 12 started
Epoch: 13 started
Epoch: 14 started
Epoch: 15 started
Epoch 15, Loss: 0.4959
Epoch: 16 started
Epoch: 17 started
Epoch: 18 started
Epoch: 19 started


SafetensorError: Error while serializing: IoError(Os { code: 28, kind: StorageFull, message: "No space left on device" })

In [16]:
import gc
import torch

gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

In [17]:
model_name = "google/mt5-small"  # or "t5-base", "t5-large"

In [18]:
torch.cuda.empty_cache()  # Clear freed GPU memory

In [22]:
import torch
torch.cuda.memory_summary(device=None, abbreviated=False)  # Print memory summary



In [None]:
# trained_model.save_pretrained(f"models/{run_name}")