In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the transliteration corpus by its `datasets` identifier.
dataset = load_dataset("SKNahin/bengali-transliteration-data")

# Hold out 20% of the "train" split for evaluation; the fixed seed keeps the
# partition identical from run to run.
train_test_split = dataset["train"].train_test_split(seed=42, test_size=0.2)
train, test = train_test_split["train"], train_test_split["test"]

In [3]:
def preprocess_data(batch):
    inputs = [f"transliterate: {text}" for text in batch["bn"]]
    targets = batch["rm"]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length").input_ids
    model_inputs["labels"] = labels
    return model_inputs

In [6]:
# Checkpoint name shared by the tokenizer here and the model loaded later.
# NOTE(review): the sample record printed further down shows the Bengali
# input encoding to a repeating "3, 2" id pattern — presumably unknown
# tokens, i.e. this vocabulary may not cover Bengali script. TODO confirm
# and consider a multilingual checkpoint.
model_name = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Tokenize both partitions batch-wise with preprocess_data.
train_dataset = train.map(preprocess_data, batched=True)
val_dataset = test.map(preprocess_data, batched=True)

In [10]:
# Trainer classes are imported here rather than in the first cell;
# NOTE(review): consider moving this import up with the others so a fresh
# kernel surfaces every dependency at once.
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Load the pretrained seq2seq weights for the same checkpoint name that the
# tokenizer was created from (model_name).
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [8]:
# Sanity check: show the first tokenized training record, then the
# validation dataset's own printed representation, one per line.
print(train_dataset[0], val_dataset, sep="\n")

{'bn': 'এটা কোনো পোস্ট হলো মিয়া আবাল', 'rm': 'eta kono post holo mia abal', 'input_ids': [3017, 9842, 342, 10, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [3, 15, 17, 9, 10447, 32, 442, 3, 2831, 32, 1337, 9, 703, 138, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 