In [3]:
import csv

import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import DataCollatorForSeq2Seq
from transformers import MT5ForConditionalGeneration, T5Tokenizer

In [4]:
class TranslationDataset(Dataset):
    """Map-style dataset of (source, target) sentence pairs for seq2seq training.

    Each item is the tokenizer output for one DataFrame row: ``input_ids`` and
    ``attention_mask`` for the source text (column 0) and ``labels`` for the
    target text (column 1), each as a 1-D tensor. No padding is applied here;
    batching/padding is left to the DataLoader's collate function.

    Args:
        data_frame: DataFrame whose first column holds source sentences and
            whose second column holds target sentences (read positionally).
        tokenizer_name: checkpoint used to load a ``T5Tokenizer`` when
            ``tokenizer`` is not given.
        tokenizer: optional already-loaded tokenizer. Passing a shared instance
            avoids loading the same tokenizer once per dataset.
        max_length: optional cap on token length for both source and target;
            when set, longer sequences are truncated. ``None`` (default) keeps
            full sequences, matching the original behaviour.
    """

    def __init__(self, data_frame, tokenizer_name='google/mt5-base', tokenizer=None, max_length=None) -> None:
        super().__init__()
        self.dataframe = data_frame
        # Only hit from_pretrained when no tokenizer instance was supplied.
        if tokenizer is None:
            tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of sentence pairs (rows) in the DataFrame."""
        return self.dataframe.shape[0]

    def __getitem__(self, index):
        """Tokenize row ``index`` and return 1-D ``input_ids``/``attention_mask``/``labels``."""
        source_text = self.dataframe.iloc[index, 0]
        target_text = self.dataframe.iloc[index, 1]

        # Truncation kwargs are only passed when a cap was requested, so the
        # default call is exactly the untruncated one.
        length_kwargs = {}
        if self.max_length is not None:
            length_kwargs = {'truncation': True, 'max_length': self.max_length}

        tokenized_data = self.tokenizer(
            source_text, text_target=target_text, return_tensors="pt", **length_kwargs
        )

        # return_tensors="pt" on a single string yields a leading batch axis of
        # size 1; drop it so each item is one un-batched sequence.
        for key in ('input_ids', 'attention_mask', 'labels'):
            tokenized_data[key] = tokenized_data[key].squeeze(0)
        return tokenized_data


In [5]:

# Model and tokenizer must come from the same checkpoint so token ids match the
# model's embedding table; load both from the single name defined here.
tokenizer_name = "google/mt5-base"
translation_model = MT5ForConditionalGeneration.from_pretrained(tokenizer_name)
t5_tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)

# Pads each batch's inputs with the tokenizer and its labels with the
# collator's label pad id (-100 by default); passing the model lets the
# collator also derive decoder_input_ids from the labels.
data_collator = DataCollatorForSeq2Seq(t5_tokenizer, model=translation_model, return_tensors="pt")

In [6]:
def collate_fn(batch_data):
    """Turn a list of dataset items into one batch for the DataLoader.

    Delegates to the module-level ``data_collator``; the global is looked up
    when the function is called, not when it is defined.
    """
    padded_batch = data_collator(batch_data)
    return padded_batch

In [7]:
# Sentence pairs, presumably English source / German target (file name EN-DE).
# TranslationDataset reads column 0 as source and column 1 as target; the
# remaining columns are not used in this notebook.
N_SAMPLES = 100  # rows to load for a quick run; set to None to read the whole file
SPLIT_SEED = 42  # makes the train/valid split identical across kernel restarts

data_frame = pd.read_csv(
    "./data/EN-DE.txt",
    sep='\t',
    # header=0 discards the file's first line in favour of `names`.
    # NOTE(review): confirm the file really starts with a header row.
    header=0,
    names=['src', 'trg', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6'],
    # Sentences may begin with a literal '"'; with default quoting the parser
    # would treat it as a quoted field and could merge or misalign rows.
    quoting=csv.QUOTE_NONE,
    # Read only the rows we keep instead of parsing everything and slicing.
    nrows=N_SAMPLES,
).dropna(subset=['src', 'trg'])  # a missing sentence would reach the tokenizer as a float NaN

train_df, valid_df = train_test_split(data_frame, test_size=0.07, random_state=SPLIT_SEED)

In [8]:
train_dataset = TranslationDataset(train_df)
valid_dataset = TranslationDataset(valid_df)

BATCH_SIZE = 6
SHUFFLE_SEED = 42  # fixes the order of training batches across kernel restarts

# shuffle=True alone draws from torch's global RNG, so the training-batch order
# would change on every run; a dedicated seeded generator makes it reproducible.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=False,
)


In [15]:
# Smoke test: push one collated training batch through the model. The batch
# carries `labels`, so the model output includes a loss; show it as the cell
# output so the check actually reports something.
batch_data = next(iter(train_dataloader))
outs = translation_model(**batch_data)
outs.loss