In [1]:
import torch
from torch.utils.data import Dataset
import os
import sentencepiece as spm

In [2]:
# LanguageDataset

from typing import Dict, List, Optional

class LanguageDataset(Dataset):
    """Map-style dataset over already-tokenized translation examples.

    ``raw_data`` is stored untouched and indexed directly, so any sized,
    indexable collection works (for example the list returned by
    :meth:`load_raw_data`).

    Args:
        raw_data: Sized, indexable collection of examples.
        tokenizer: Optional object exposing ``encode(text)``. It is only
            required by :meth:`processData`.
    """

    # Ids written by ``processData``.
    PAD_ID = 1            # fills every sequence up to ``max_length``
    EOS_ID = 2            # appended right after the last token
    # Appended after EOS on the target side; presumably the model id of the
    # target-language tag ("spa_Latn") — TODO confirm against the vocabulary.
    TRG_LANG_ID = 256161

    def __init__(self, raw_data, tokenizer=None):
        self.data = raw_data
        # Kept under the name ``s`` because ``processData`` reads ``self.s``.
        self.s = tokenizer

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the example stored at position ``idx``."""
        return self.data[idx]

    def _encode(self, text):
        """Encode ``text`` with ``self.s`` and shift every id up by one."""
        ids = self.s.encode(text)
        # A batch-style encode returns one id list per input; keep the first.
        if ids and isinstance(ids[0], (list, tuple)):
            ids = ids[0]
        # The +1 shift presumably aligns tokenizer ids with the model
        # vocabulary — TODO confirm.
        return [token_id + 1 for token_id in ids]

    @staticmethod
    def _fit_to_length(ids, tail, max_length, pad_id):
        """Append ``tail`` to ``ids`` and force the result to ``max_length``.

        Returns:
            ``(sequence, n_real)`` where ``n_real`` counts the leading
            positions that are not padding.
        """
        seq = ids + tail
        if len(seq) > max_length:
            # Reserve the last two slots for the tail. A one-element tail
            # leaves one slot free, which is filled with padding below.
            seq = seq[:max_length - 2] + tail
        n_real = len(seq)
        return seq + [pad_id] * (max_length - n_real), n_real

    def processData(self, src_text, trg_text, src_code=0, max_length=64):
        """Turn one source/target pair into fixed-length id tensors.

        Args:
            src_text: Source sentence handed to ``self.s.encode``.
            trg_text: Target sentence, or ``None`` to skip label creation.
            src_code: Id appended after EOS on the source side; ``0`` means
                "do not append anything".
            max_length: Exact length of every returned sequence (>= 2).

        Returns:
            dict with 1-D ``torch`` tensors of length ``max_length``:
            ``input_ids``, ``attention_mask`` (1 on real positions, 0 on
            padding) and, when ``trg_text`` is given, ``labels``.

        Raises:
            RuntimeError: if the dataset was built without a tokenizer.
            ValueError: if ``max_length`` is smaller than 2.
        """
        if self.s is None:
            raise RuntimeError(
                "processData needs a tokenizer; pass tokenizer=... to LanguageDataset."
            )
        if max_length < 2:
            raise ValueError("max_length must be at least 2.")

        src_tail = [self.EOS_ID]
        if src_code != 0:
            src_tail.append(src_code)
        input_ids, n_real = self._fit_to_length(
            self._encode(src_text), src_tail, max_length, self.PAD_ID
        )
        # Built from the length, not by comparing against PAD_ID, so a real
        # token that happens to share the padding id is still attended to.
        attention_mask = [1] * n_real + [0] * (max_length - n_real)

        features = {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
        }

        if trg_text is not None:
            labels, _ = self._fit_to_length(
                self._encode(trg_text),
                [self.EOS_ID, self.TRG_LANG_ID],
                max_length,
                self.PAD_ID,
            )
            features['labels'] = torch.tensor(labels)

        return features

    @staticmethod
    def load_raw_data(src_filepath: List[str], lang_code: List[str], model_name: str,
                      trg_filepath: Optional[List[str]] = None, max_length: int = 256,
                      tgt_lang: str = "spa_Latn"):
        """Read parallel text files and tokenize them line by line.

        Args:
            src_filepath: One source-text file per language.
            lang_code: Source-language code for each entry of ``src_filepath``.
            model_name: Checkpoint name passed to ``AutoTokenizer.from_pretrained``.
            trg_filepath: Optional target-text files, parallel to ``src_filepath``.
            max_length: Length every example is padded or truncated to.
            tgt_lang: Target-language code given to the tokenizer.

        Returns:
            A list with one tokenizer output per line, in file order.

        Raises:
            ValueError: if the argument lists differ in length, or a
                source/target file pair differs in line count.
        """
        # Imported here so the function works regardless of which notebook
        # cell performed the transformers import.
        from transformers import AutoTokenizer

        len_src = len(src_filepath)
        if len_src != len(lang_code):
            raise ValueError("Lengths of src_filepath and lang_code don't match.")

        is_trg = bool(trg_filepath)
        if is_trg and len_src != len(trg_filepath):
            raise ValueError("Lengths of src_filepath and trg_filepath don't match.")

        token_data = []
        for i in range(len_src):
            code = lang_code[i]
            tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang=code, tgt_lang=tgt_lang)

            # Fresh list for every language: each line must be tokenized
            # exactly once, with the tokenizer configured for its own language.
            with open(src_filepath[i], encoding='utf-8') as f:
                src_lines = [line.strip() for line in f]

            if is_trg:
                with open(trg_filepath[i], encoding='utf-8') as f:
                    trg_lines = [line.strip() for line in f]
                if len(src_lines) != len(trg_lines):
                    # zip() would silently drop the unpaired tail.
                    raise ValueError(
                        "Line counts of %s and %s don't match." % (src_filepath[i], trg_filepath[i])
                    )
                for src_text, trg_text in zip(src_lines, trg_lines):
                    token_data.append(tokenizer(src_text, text_target=trg_text,
                                                max_length=max_length, padding='max_length', truncation=True))
            else:
                for src_text in src_lines:
                    token_data.append(tokenizer(src_text, max_length=max_length, padding='max_length', truncation=True))

        return token_data

In [3]:
from torch.utils.data import random_split,  DataLoader
import pytorch_lightning as pl
from typing import Optional
from utils import load_raw_data, predict
#from dataset import load_dataset
#import LanguageDataset

class LanugageDataModule(pl.LightningDataModule):
    """
    DataModule that serves tokenized translation examples to a Lightning Trainer.

    Args:
        data: Training examples; 90% become the training split and the
            remaining 10% the validation split.
        eval: Held-out examples served by ``test_dataloader``.
        batch_size: Batch size used by every dataloader.
        numOfWorker: Number of worker processes for every dataloader.
        model_name: Optional checkpoint name, stored on the module for reference.
    """

    def __init__(self, data, eval, batch_size, numOfWorker = 8, model_name = None):
        # Lightning sets up its own internal state in the base constructor.
        super().__init__()
        self.data = data
        # Stored as ``eval_data`` so the attribute name cannot be mistaken
        # for an ``eval()`` method.
        self.eval_data = eval
        self.batch_size = batch_size
        self.model_name = model_name
        self.numOfWorker = numOfWorker

    def prepare_data(self):
        """
        Empty prepare_data method left in intentionally.
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html#prepare-data
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """
        Build the datasets needed for ``stage``.
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html#setup

        ``"fit"`` and ``"validate"`` create the train/validation split,
        ``"test"`` creates the test set, and ``None`` creates everything.
        """
        # The validation split is needed both while fitting and when only
        # validating, so both stages build it.
        if stage in ("fit", "validate") or stage is None:
            train_set_full = LanguageDataset(self.data)
            train_set_size = int(len(train_set_full) * 0.9)
            valid_set_size = len(train_set_full) - train_set_size
            self.train, self.validate = random_split(train_set_full, [train_set_size, valid_set_size])

        # The held-out examples handed to the constructor form the test set.
        if stage == "test" or stage is None:
            self.test = LanguageDataset(self.eval_data)

    # Dataloaders for train, validate and test; predict is not implemented.
    def train_dataloader(self):
        """Return a shuffling dataloader over the training split."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True, num_workers=self.numOfWorker)

    def val_dataloader(self):
        """Return a dataloader over the validation split."""
        return DataLoader(self.validate, batch_size=self.batch_size, num_workers=self.numOfWorker)

    def test_dataloader(self):
        """Return a dataloader over the held-out test set."""
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.numOfWorker)

In [4]:
import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.translation import (
    TranslationTransformer,
    WMT16TranslationDataModule,
)
# AutoModelForSeq2SeqLM is used below; imported here so this cell is self-contained.
from transformers import AutoModelForSeq2SeqLM

main_folder =  '../processed_data/'
model_name = "facebook/nllb-200-distilled-600M"

# (data folder, source-file extension) for every source language.
# Every file list below is derived from this one table, so the lists
# always stay parallel to each other.
language_files = [
    ('ashaninka', 'cni'),
    ('aymara', 'aym'),
    ('bribri', 'bzd'),
    ('guarani', 'gn'),
    ('hñähñu', 'oto'),
    ('nahuatl', 'nah'),
    ('quechua', 'quy'),
    ('raramuri', 'tar'),
    ('shipibo_konibo', 'shp'),
    ('wixarika', 'hch'),
]

# Training corpus: <folder>/dedup_filtered.<ext> paired with <folder>/dedup_filtered.es
train_src_filepath = [f"{main_folder}{folder}/dedup_filtered.{ext}" for folder, ext in language_files]
train_trg_filepath = [f"{main_folder}{folder}/dedup_filtered.es" for folder, _ in language_files]

# Evaluation corpus: <folder>/dev.<ext> paired with <folder>/dev.es
eval_src_filepath = [f"{main_folder}{folder}/dev.{ext}" for folder, ext in language_files]
eval_trg_filepath = [f"{main_folder}{folder}/dev.es" for folder, _ in language_files]

# NOTE(review): codes are built as "<ext>_Latn". Confirm that each one exists
# in the tokenizer's language list; a code the tokenizer does not know
# presumably will not map to a dedicated language token.
lang_code = [f"{ext}_Latn" for _, ext in language_files]

data = load_raw_data(train_src_filepath, lang_code, model_name, train_trg_filepath, max_length=128)
# NOTE: `eval` shadows the Python builtin of the same name for the rest of
# the notebook; the name is kept because other cells may rely on it.
eval = load_raw_data(eval_src_filepath, lang_code, model_name, eval_trg_filepath, max_length=128)


model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
dm = LanugageDataModule(
    data = data,
    eval = eval,
    batch_size = 32,
)
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=1)

# NOTE(review): `model` is a plain Hugging Face model, while `Trainer.fit`
# expects a `pl.LightningModule`; this call likely needs a LightningModule
# wrapper (training_step / configure_optimizers) — confirm before running.
trainer.fit(model, dm)

KeyboardInterrupt: 