In [63]:
from datasets import load_dataset
import pandas as pd
import pytorch_lightning as pl
from transformers import DataCollatorForSeq2Seq, AutoTokenizer
from torch.utils.data import DataLoader
import os
import pickle
from datasets import load_dataset

In [2]:
# Greek-only FLORES config ("ell_Grek"); it exposes `dev` and `devtest` splits.
dataset = load_dataset(path="facebook/flores", name="ell_Grek")

In [3]:
# Show the DatasetDict summary: split names, feature columns and row counts.
dataset

DatasetDict({
    dev: Dataset({
        features: ['id', 'URL', 'domain', 'topic', 'has_image', 'has_hyperlink', 'sentence'],
        num_rows: 997
    })
    devtest: Dataset({
        features: ['id', 'URL', 'domain', 'topic', 'has_image', 'has_hyperlink', 'sentence'],
        num_rows: 1012
    })
})

In [4]:
# Greek sentences as DataFrames: `df` holds the dev split, `df_dev` the devtest split.
df, df_dev = (pd.DataFrame(dataset[split]) for split in ("dev", "devtest"))
df['sentence'].iloc[0]

'Τη Δευτέρα, επιστήμονες από την Ιατρική Σχολή του Πανεπιστημίου του Στάνφορντ ανακοίνωσαν την εφεύρεση ενός νέου εργαλείου διάγνωσης με δυνατότητα ομαδοποίησης των κυττάρων ανά τύπο: ένα μικροσκοπικό εκτυπώσιμο τσιπ που μπορεί να κατασκευαστεί με απλούς εκτυπωτές ψεκασμού μελάνης με κόστος περίπου ένα σεντ του αμερικανικού δολαρίου το καθένα.'

In [5]:
# Despite the name, the "all" config contains every language (one sentence column each).
dataset_eng = load_dataset(path="facebook/flores", name="all")
df_eng = pd.DataFrame(data=dataset_eng["dev"])

In [6]:
# The "all" config has one `sentence_<lang>_<script>` column per language
# (e.g. sentence_eng_Latn, sentence_ell_Grek) next to the shared metadata columns.
dataset_eng

DatasetDict({
    dev: Dataset({
        features: ['id', 'URL', 'domain', 'topic', 'has_image', 'has_hyperlink', 'sentence_ace_Arab', 'sentence_bam_Latn', 'sentence_dzo_Tibt', 'sentence_hin_Deva', 'sentence_khm_Khmr', 'sentence_mag_Deva', 'sentence_pap_Latn', 'sentence_sot_Latn', 'sentence_tur_Latn', 'sentence_ace_Latn', 'sentence_ban_Latn', 'sentence_ell_Grek', 'sentence_hne_Deva', 'sentence_kik_Latn', 'sentence_mai_Deva', 'sentence_pbt_Arab', 'sentence_spa_Latn', 'sentence_twi_Latn', 'sentence_acm_Arab', 'sentence_bel_Cyrl', 'sentence_eng_Latn', 'sentence_hrv_Latn', 'sentence_kin_Latn', 'sentence_mal_Mlym', 'sentence_pes_Arab', 'sentence_srd_Latn', 'sentence_tzm_Tfng', 'sentence_acq_Arab', 'sentence_bem_Latn', 'sentence_epo_Latn', 'sentence_hun_Latn', 'sentence_kir_Cyrl', 'sentence_mar_Deva', 'sentence_plt_Latn', 'sentence_srp_Cyrl', 'sentence_uig_Arab', 'sentence_aeb_Arab', 'sentence_ben_Beng', 'sentence_est_Latn', 'sentence_hye_Armn', 'sentence_kmb_Latn', 'sentence_min_Arab', 'sen

In [7]:
# Same "all" config as `dataset_eng`; load_dataset normally serves it from the local cache.
dataset_all = load_dataset(path="facebook/flores", name="all")

In [8]:
# English plus the three target languages that have translation presets in
# model.config.task_specific_params (German, French, Romanian).
language_to_choose = ["eng_Latn", "deu_Latn", "fra_Latn", "ron_Latn"]
# One list of dev sentences per language, in the order of `language_to_choose`.
my_list = [dataset_all['dev'][f'sentence_{lang}'] for lang in language_to_choose]

In [9]:
# Romanian (ron_Latn) dev sentences.
my_list[language_to_choose.index("ron_Latn")]

['Luni, oameni de știință de la Facultatea de Medicină a Universității Stanford au anunțat inventarea unui nou instrument de diagnosticare, care poate sorta celulele în funcție de tipul lor: un cip minuscul, printabil, care poate fi produs folosind imprimante obișnuite cu jet de cerneală, pentru aproximativ un cent american bucata.',
 'Cercetătorii principali spun că acest lucru ar putea duce la detectarea precoce a bolilor precum cancerul, tuberculoza, SIDA sau malaria la pacienții din țări cu venituri mici, unde ratele de supraviețuire în boli precum cancerul de sân pot fi de două ori mai mici decât în țările bogate.',
 'Aparatul JAS 39C Gripen s-a prăbușit pe o pistă în jurul orei locale 9:20 a.m. (0230 UTC) și a explodat, ducând la închiderea aeroportului pentru zboruri comerciale.',
 'Pilotul a fost identificat ca fiind liderul escadrilei Dilokrit Pattavee.',
 'Presa locală raportează că un vehicul de pompieri din aeroport s-a răsturnat în timp ce se ducea spre o urgență.',
 'Vida

In [10]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Small T5 checkpoint from the Hugging Face Hub, with its matching tokenizer.
model_name = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


In [11]:
# Task presets shipped with the checkpoint; the translation prefixes
# ("translate English to German: ", ...) match the prefix format the DataModule builds.
model.config.task_specific_params

{'summarization': {'early_stopping': True,
  'length_penalty': 2.0,
  'max_length': 200,
  'min_length': 30,
  'no_repeat_ngram_size': 3,
  'num_beams': 4,
  'prefix': 'summarize: '},
 'translation_en_to_de': {'early_stopping': True,
  'max_length': 300,
  'num_beams': 4,
  'prefix': 'translate English to German: '},
 'translation_en_to_fr': {'early_stopping': True,
  'max_length': 300,
  'num_beams': 4,
  'prefix': 'translate English to French: '},
 'translation_en_to_ro': {'early_stopping': True,
  'max_length': 300,
  'num_beams': 4,
  'prefix': 'translate English to Romanian: '}}

In [12]:
# Scratch check: capitalize() upper-cases only the first character.
str.capitalize('english')

'English'

# Testing the T5 translation DataModule

In [246]:
class T5TranslationDataModule(pl.LightningDataModule):
    """LightningDataModule serving English -> X translation pairs from FLORES for T5.

    All splits are carved out of the FLORES ``dev`` split (shuffled with ``seed_num``)
    as disjoint, consecutive row ranges, so no sentence appears in two splits:

        train      = rows [0, train_range)
        validation = rows [train_range, train_range + val_range)
        test       = rows [train_range + val_range, train_range + val_range + test_range)

    Ranges are clipped to the number of available rows. Every split is built once per
    target language and the per-language results are concatenated into a plain list
    of tokenized examples (``input_ids``, ``attention_mask``, ``labels``).

    Args:
        model_name: Hugging Face model id used to load the tokenizer.
        dataset_name: Hugging Face dataset id with an ``"all"`` config providing
            ``sentence_eng_Latn`` and ``sentence_<language>`` columns.
        max_length: Length every source and target sequence is padded/truncated to.
        batch_size: Batch size of all dataloaders.
        train_range, val_range, test_range: Rows per split, per target language.
        seed_num: Shuffle seed; also part of the cache directory and file names.
        languages: Target language codes; each must be a key of ``LANGUAGE_NAMES``.
        mask_label_padding: If True, padding positions in ``labels`` are set to -100,
            the label id Hugging Face seq2seq losses ignore. Enable it for training.
            It is off by default because -100 is not a decodable token id.
    """

    # FLORES language code -> language name used in the T5 prefix
    # "translate English to <name>: ".
    LANGUAGE_NAMES = {
        "deu_Latn": "German",
        "fra_Latn": "French",
        "ron_Latn": "Romanian",
    }

    def __init__(self, model_name, dataset_name, max_length,
                 batch_size, train_range, val_range, test_range, seed_num,
                 languages, mask_label_padding=False):
        super().__init__()
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.train_range = train_range
        self.val_range = val_range
        self.test_range = test_range
        self.seed_num = seed_num
        self.languages = languages
        self.mask_label_padding = mask_label_padding
        self.tokenizer = None
        self.data_collator = None
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []
        self.cache_dir = f"./dataset_cache_{self.seed_num}"
        # Shuffled ``dev`` split, loaded lazily once and shared by all splits/languages.
        self._shuffled_dev = None

    def prepare_data(self):
        """Download the dataset and tokenizer; the returned objects are discarded."""
        load_dataset(self.dataset_name, 'all')
        AutoTokenizer.from_pretrained(self.model_name)

    def setup(self, stage=None):
        """Build the tokenizer, the collator and the tokenized datasets for ``stage``.

        'fit' builds train and validation, 'validate' builds validation, 'test' builds
        test, and None builds all three.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # The collator's optional ``model`` must be a model object (it is only used to
        # build decoder_input_ids); a model *name* string is meaningless there.
        self.data_collator = DataCollatorForSeq2Seq(tokenizer=self.tokenizer)

        if stage in ('fit', None):
            self.train_datasets = self._get_or_process_dataset('train')
        if stage in ('fit', 'validate', None):
            self.val_datasets = self._get_or_process_dataset('validation')
        if stage in ('test', None):
            self.test_datasets = self._get_or_process_dataset('test')

        print(f"Setup complete. Datasets sizes: Train: {len(self.train_datasets)}, Val: {len(self.val_datasets)}, Test: {len(self.test_datasets)}")

    def _get_or_process_dataset(self, split):
        """Return the tokenized examples of ``split`` for every language, concatenated.

        Each (split, language) pair is read from a pickle cache when present; otherwise
        it is tokenized from the shuffled ``dev`` split and written to the cache.
        """
        start, stop = self._split_bounds(split)
        combined_dataset = []

        for language in self.languages:
            cache_file = os.path.join(self.cache_dir, self._cache_name(split, language, start, stop))

            if os.path.exists(cache_file):
                print(f"Loading cached {split} dataset for {language}...")
                # NOTE: pickle.load can execute arbitrary code; only trust cache files
                # written by this class.
                with open(cache_file, 'rb') as f:
                    dataset = pickle.load(f)
                print(f"Loaded {split} dataset for {language} with {len(dataset)} samples")
            else:
                print(f"Processing {split} dataset for {language}...")
                dev = self._load_shuffled_dev()
                n_rows = len(dev)
                data = dev.select(range(min(start, n_rows), min(stop, n_rows)))
                dataset = self._preprocess_dataset(data, language)

                os.makedirs(self.cache_dir, exist_ok=True)
                with open(cache_file, 'wb') as f:
                    pickle.dump(dataset, f)

            combined_dataset.extend(dataset)

        return combined_dataset

    def _split_bounds(self, split):
        """Return the unclipped [start, stop) row range of ``split``.

        Ranges are consecutive so train, validation and test never share rows.
        Raises ValueError for an unknown split name.
        """
        val_start = self.train_range
        test_start = val_start + self.val_range
        bounds = {
            'train': (0, val_start),
            'validation': (val_start, test_start),
            'test': (test_start, test_start + self.test_range),
        }
        if split not in bounds:
            raise ValueError(f"Unknown split {split!r}; expected one of {sorted(bounds)}")
        return bounds[split]

    def _cache_name(self, split, language, start, stop):
        """Cache file name encoding every setting that changes the cached content."""
        # Changing the row range, max_length, model or masking must never reuse a stale pickle.
        model_tag = str(self.model_name).replace('/', '_')
        masked = '_masked' if self.mask_label_padding else ''
        return (f"{split}_{language}_{self.seed_num}_rows{start}-{stop}"
                f"_len{self.max_length}_{model_tag}{masked}.pkl")

    def _load_shuffled_dev(self):
        """Load the ``dev`` split of the "all" config once, shuffled with ``seed_num``."""
        if self._shuffled_dev is None:
            self._shuffled_dev = load_dataset(self.dataset_name, 'all')['dev'].shuffle(seed=self.seed_num)
        return self._shuffled_dev

    def _preprocess_dataset(self, dataset, target_language):
        """Tokenize English -> ``target_language`` pairs of ``dataset`` for T5.

        Inputs get the prefix "translate English to <Language>: ". Inputs and targets
        are both padded/truncated to ``max_length``. All original columns are removed,
        leaving ``input_ids``, ``attention_mask`` and ``labels``.

        Raises KeyError if ``target_language`` is not in ``LANGUAGE_NAMES``.
        """
        try:
            target_lang_name = self.LANGUAGE_NAMES[target_language]
        except KeyError:
            raise KeyError(
                f"Unsupported target language {target_language!r}; "
                f"expected one of {sorted(self.LANGUAGE_NAMES)}"
            ) from None
        prefix = f"translate English to {target_lang_name}: "
        target_column = f'sentence_{target_language}'

        def preprocess_function(examples):
            inputs = [prefix + sentence for sentence in examples['sentence_eng_Latn']]
            tokenized_inputs = self.tokenizer(inputs, max_length=self.max_length,
                                              padding="max_length", truncation=True)
            tokenized_targets = self.tokenizer(examples[target_column], max_length=self.max_length,
                                               padding="max_length", truncation=True)
            labels = tokenized_targets["input_ids"]
            if self.mask_label_padding:
                # Replace padded target positions (attention_mask == 0) with -100.
                labels = [
                    [token if keep else -100 for token, keep in zip(ids, mask)]
                    for ids, mask in zip(labels, tokenized_targets["attention_mask"])
                ]
            return {
                "input_ids": tokenized_inputs["input_ids"],
                "attention_mask": tokenized_inputs["attention_mask"],
                "labels": labels,
            }

        return dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=dataset.column_names
        )

    def train_dataloader(self):
        """Shuffled training batches; a trailing partial batch is dropped."""
        return DataLoader(self.train_datasets, batch_size=self.batch_size, collate_fn=self.data_collator, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Validation batches in order. NOTE: drop_last=True skips up to batch_size-1 examples."""
        return DataLoader(self.val_datasets, batch_size=self.batch_size, collate_fn=self.data_collator, drop_last=True)

    def test_dataloader(self):
        """Test batches in order. NOTE: drop_last=True skips up to batch_size-1 examples."""
        return DataLoader(self.test_datasets, batch_size=self.batch_size, collate_fn=self.data_collator, drop_last=True)

In [273]:
# DataModule configuration (swap in another model / dataset id as needed).
model_name = "t5-small"
dataset_name = "facebook/flores"
max_length = 128  # tokenizer padding/truncation length for sources and targets
batch_size = 32
# Rows drawn per split for each target language; adjust as needed.
train_range, val_range, test_range = 100, 100, 100
seed_num = 42
languages = ["deu_Latn", "fra_Latn", "ron_Latn"]

data_module = T5TranslationDataModule(
    model_name=model_name,
    dataset_name=dataset_name,
    max_length=max_length,
    batch_size=batch_size,
    train_range=train_range,
    val_range=val_range,
    test_range=test_range,
    seed_num=seed_num,
    languages=languages,
)

In [274]:
# Tokenize (or load from the pickle cache) the train and validation splits.
data_module.setup('fit')

Processing train dataset for deu_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1216.82 examples/s]


Processing train dataset for fra_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1381.29 examples/s]


Processing train dataset for ron_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1327.51 examples/s]


Processing validation dataset for deu_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1542.97 examples/s]


Processing validation dataset for fra_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1470.87 examples/s]


Processing validation dataset for ron_Latn...


Map: 100%|██████████| 100/100 [00:00<00:00, 1543.81 examples/s]

Setup complete. Datasets sizes: Train: 300, Val: 300, Test: 0





In [275]:
# Decode one full pass over the training DataLoader (shuffling + collator included)
# so the tokenized pairs can be inspected as text: source from input_ids, target
# from labels. drop_last=True means a trailing partial batch is not included.
# NOTE: " ".join(batch_decode(ids)) decodes token by token, so words show up split
# into pieces (e.g. "Me cca") and <pad>/</s> are kept; the next cell relies on the
# resulting " : " after the task prefix. tokenizer.decode(ids, skip_special_tokens=True)
# gives readable sentences if that parsing is adapted.
list_with_english = []
list_with_translated = []
pad_id = data_module.tokenizer.pad_token_id
for batch in data_module.train_dataloader():
    # DataCollatorForSeq2Seq pads labels with -100 by default, which is not a
    # decodable token id; map it back to <pad> before decoding.
    batch_labels = batch['labels'].masked_fill(batch['labels'] < 0, pad_id)
    for input_ids, label_ids in zip(batch['input_ids'], batch_labels):
        list_with_english.append(" ".join(data_module.tokenizer.batch_decode(input_ids)))
        list_with_translated.append(" ".join(data_module.tokenizer.batch_decode(label_ids)))


In [290]:
# Pair the decoded sources and targets, then list the distinct English sentences.
# str.partition splits only at the FIRST " : " (the end of the task prefix), so
# sentences that themselves contain a colon are no longer truncated, and rows
# without the separator yield "" instead of raising IndexError.
df = pd.DataFrame({"English": list_with_english, "Translated": list_with_translated})
df['English'].str.partition(" : ")[2].unique()

array(["A hostel collapse d in Me cca , the holy city of Islam at about 10  o ' clock this morning local time . </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>",
       'O cca s ional specialist air tours go  inland , for mountain e er ing or to reach the Pole , which has  a large base . </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>