In [45]:
import os
import random
from glob import glob

import numpy as np
import regex as re
import tqdm
from datasets import load_dataset
from sacrebleu import BLEU
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    EncoderDecoderModel,
    MT5ForConditionalGeneration,
    Text2TextGenerationPipeline,
    Trainer,
    TrainingArguments,
)

In [2]:
# Corpus-level sacreBLEU scorer with default settings; the scoring cell calls
# `bleu.corpus_score(hyp, [ref])` with it for every (model, language) pair.
bleu = BLEU()

In [3]:
class Text2TextDataset(Dataset):
    """Map-style dataset that yields raw, untokenized prompt strings.

    Tokenization is left to the consumer: in this notebook the dataset is fed
    straight to a ``Text2TextGenerationPipeline``, which is constructed with
    its own tokenizer. Wrapping the prompts in a ``Dataset`` lets the pipeline
    iterate over them (it is created with a ``batch_size``).
    """

    def __init__(self, inputs):
        """
        Args:
            inputs: indexable sequence of input strings (a list of prompts here).
        """
        self.inputs = inputs

    def __len__(self):
        """Number of prompts in the dataset."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return the prompt string stored at ``index`` unchanged."""
        return self.inputs[index]

In [4]:
# ISO 639-1 language code -> English language name; the name is what goes into
# the "generate <Language> :" task prefix when building model inputs.
lang_codes = dict(
    cy="Welsh",
    br="Breton",
    ga="Irish",
    mt="Maltese",
    ru="Russian",
    de="German",
    en="English",
)

In [5]:
# Base mT5 checkpoint. Its tokenizer is handed to the generation pipeline for
# every fine-tuned mT5 checkpoint evaluated below (presumably they were all
# fine-tuned from mt5-small — confirm).
# NOTE(review): `model_name` is rebound by later loops, and `tokenizer` is
# overwritten with a RoBERTa tokenizer in a later cell, so re-run this cell
# before re-running the mT5 generation cell.
model_name = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)



In [6]:
# Local directories of the fine-tuned mT5 checkpoints to evaluate.
# Every path ends with "/" so that the last path component (the run name) is
# derived the same way for each checkpoint. Previously "mt5_denoised" lacked the
# slash, and its outputs were dumped under the name "gsoc".
model_names = [
    '/scratch/aditya_hari/gsoc/mt5_pro/',
    '/scratch/aditya_hari/gsoc/mt5_denoised/',
    '/scratch/aditya_hari/gsoc/mt5-iter-final/',
]

In [37]:
# Load the held-out eval split for each language.
# lang_data[lang] = [sources, targets], where each source line is prefixed with
# the "generate <Language> :" task prompt, runs of 2+ spaces are collapsed, and
# surrounding whitespace (including the newline) is stripped.
EVAL_DATA_DIR = '/home2/aditya_hari/gsoc/data/processed'

lang_data = {}
for lang in ['en', 'de', 'ga']:
    lang_dir = os.path.join(EVAL_DATA_DIR, lang)
    # `with` closes the files; explicit UTF-8 avoids depending on the locale
    # for non-ASCII German/Irish text.
    with open(os.path.join(lang_dir, 'eval_src'), 'r', encoding='utf-8') as src_file:
        eval_srcs = [
            re.sub(r"[ ]{2,}", " ", f"generate {lang_codes[lang]} : {line}").strip()
            for line in src_file
        ]
    with open(os.path.join(lang_dir, 'eval_tgt'), 'r', encoding='utf-8') as tgt_file:
        eval_tgts = [line.strip() for line in tgt_file]
    # BLEU pairs hypotheses and references by position, so the splits must align.
    assert len(eval_srcs) == len(eval_tgts), (
        f"{lang}: {len(eval_srcs)} source lines vs {len(eval_tgts)} target lines"
    )
    lang_data[lang] = [eval_srcs, eval_tgts]

In [40]:
# Generate eval-set predictions for every fine-tuned mT5 checkpoint.
# outputs[model_path][lang] is a list of generated strings, in eval input order.
outputs = {}
for model_path in model_names:  # distinct name: don't clobber the global `model_name`
    pipe = Text2TextGenerationPipeline(
        model=MT5ForConditionalGeneration.from_pretrained(model_path),
        tokenizer=tokenizer,
        batch_size=32,
        device=0,
        num_beams=5,
        early_stopping=True,
    )
    per_lang = outputs.setdefault(model_path, {})
    for lang in ['en', 'de', 'ga']:
        eval_inputs = Text2TextDataset(lang_data[lang][0])
        # The context manager closes each bar. Unclosed bars leak and interleave
        # with later output. The label goes in `desc` instead of a separate print.
        with tqdm.tqdm(total=len(eval_inputs), desc=f"{model_path} {lang}") as pb:
            outs = []
            for out in pipe(eval_inputs):
                pb.update(1)
                # each `out` is a list of result dicts for one input
                outs.extend(item['generated_text'] for item in out)
        per_lang[lang] = outs

/scratch/aditya_hari/gsoc/mt5_pro/ en


 22%|██▏       | 192/869 [00:39<02:21,  4.80it/s]


/scratch/aditya_hari/gsoc/mt5_pro/ de


100%|██████████| 1618/1618 [00:31<00:00, 51.14it/s]
100%|█████████▉| 865/869 [00:17<00:00, 57.53it/s]

/scratch/aditya_hari/gsoc/mt5_pro/ ga


100%|██████████| 869/869 [00:17<00:00, 50.66it/s]


/scratch/aditya_hari/gsoc/mt5_denoised en


100%|██████████| 1665/1665 [00:40<00:00, 41.22it/s]
 99%|█████████▉| 1601/1618 [00:31<00:00, 52.84it/s]

/scratch/aditya_hari/gsoc/mt5_denoised de


100%|██████████| 1618/1618 [00:31<00:00, 51.51it/s]


/scratch/aditya_hari/gsoc/mt5_denoised ga


100%|██████████| 869/869 [00:17<00:00, 50.35it/s]
100%|██████████| 1665/1665 [00:46<00:00, 58.91it/s]

/scratch/aditya_hari/gsoc/mt5-iter-final/ en


100%|██████████| 1665/1665 [00:49<00:00, 33.47it/s]


/scratch/aditya_hari/gsoc/mt5-iter-final/ de


100%|██████████| 1618/1618 [00:31<00:00, 51.03it/s]
100%|█████████▉| 865/869 [00:17<00:00, 57.38it/s]

/scratch/aditya_hari/gsoc/mt5-iter-final/ ga


100%|██████████| 869/869 [00:17<00:00, 50.44it/s]


In [41]:
# Spot-check: German eval item 1 as generated by each of the first three models.
tuple(outputs[list(outputs)[idx]]['de'][1] for idx in range(3))

('Der Flughafen Aarhus hat eine Fahrbahnlänge von 2702,0.',
 'Der Flughafen Aarhus hat eine Fahrbahnlänge von 2702,0.',
 'Der Flughafen Aarhus hat eine Fahrbahnlänge von 2702.0.')

In [42]:
# Write each model's generations to disk and print corpus BLEU against the references.
OUTPUT_ROOT = '/home2/aditya_hari/gsoc/rdf-to-text/src/denoising/outputs'

for model_name in outputs:
    # Use the last path component as the run name. The previous `split("/")[-2]`
    # only worked for paths with a trailing slash: it produced "gsoc" for
    # ".../mt5_denoised" and raised IndexError for "sbert_van".
    run_name = os.path.basename(os.path.normpath(model_name))
    for lang, hyp in outputs[model_name].items():
        ref = lang_data[lang][1]
        lang_dir = os.path.join(OUTPUT_ROOT, lang)
        os.makedirs(lang_dir, exist_ok=True)
        with open(os.path.join(lang_dir, run_name), 'w', encoding='utf-8') as f:
            f.write('\n'.join(hyp))
        print(model_name, lang, bleu.corpus_score(hyp, [ref]).score)

/scratch/aditya_hari/gsoc/mt5_pro/ en 20.752479348448915
/scratch/aditya_hari/gsoc/mt5_pro/ de 17.178769427498807
/scratch/aditya_hari/gsoc/mt5_pro/ ga 5.043527832978668
/scratch/aditya_hari/gsoc/mt5_denoised en 20.682504286538787
/scratch/aditya_hari/gsoc/mt5_denoised de 18.017464181172844
/scratch/aditya_hari/gsoc/mt5_denoised ga 5.012456607532915
/scratch/aditya_hari/gsoc/mt5-iter-final/ en 20.849497028402883
/scratch/aditya_hari/gsoc/mt5-iter-final/ de 17.83210463595681
/scratch/aditya_hari/gsoc/mt5-iter-final/ ga 4.838849603775116


In [43]:
# Tokenizers for the SBERT-encoder seq2seq experiment.
# NOTE(review): this rebinds `tokenizer`, replacing the mT5 tokenizer loaded
# earlier. The mT5 generation cell would silently use this RoBERTa tokenizer if
# re-run after this cell.
# NOTE(review): `tokenizer_other` (xlm-roberta-base) is not used in the cells
# shown — confirm whether the decoder expects this vocabulary for decoding.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
tokenizer_other = AutoTokenizer.from_pretrained("xlm-roberta-base")

In [46]:
# Load the fine-tuned SBERT-encoder seq2seq checkpoint.
model = EncoderDecoderModel.from_pretrained("/scratch/aditya_hari/gsoc/s2s_sbert_van")
# Register <TSP> as an additional special token on the encoder-side tokenizer
# (presumably a separator in the linearised input — confirm against data prep).
tokenizer.add_special_tokens({'additional_special_tokens': ['<TSP>']})
# Resize only the encoder's input embeddings to match the enlarged tokenizer;
# the returned Embedding module is displayed as the cell output.
model.encoder.resize_token_embeddings(len(tokenizer))

Embedding(50266, 768, padding_idx=1)

In [47]:
# Generate eval-set predictions from the SBERT-encoder seq2seq model and store
# them in `outputs` next to the mT5 runs, under the key 'sbert_van'.
# NOTE(review): `tokenizer` here is the distilroberta tokenizer from the cell
# above. Confirm it matches the decoder's vocabulary before trusting the outputs.
pipe = Text2TextGenerationPipeline(
    model=model,
    tokenizer=tokenizer,
    batch_size=16,
    device=0,
    num_beams=5,
    early_stopping=True,
)
model_name = 'sbert_van'
per_lang = outputs.setdefault(model_name, {})
for lang in ['en', 'de', 'ga']:
    eval_inputs = Text2TextDataset(lang_data[lang][0])
    # The context manager closes each bar, so bars no longer leak into later cells.
    with tqdm.tqdm(total=len(eval_inputs), desc=f"{model_name} {lang}") as pb:
        outs = []
        for out in pipe(eval_inputs):
            pb.update(1)
            # each `out` is a list of result dicts for one input
            outs.extend(item['generated_text'] for item in out)
    per_lang[lang] = outs

sbert_s2s en


100%|██████████| 1665/1665 [08:53<00:00,  3.12it/s]
100%|█████████▉| 1617/1618 [00:53<00:00, 31.44it/s]

sbert_s2s de


100%|██████████| 1618/1618 [00:53<00:00, 30.44it/s]


sbert_s2s ga


100%|██████████| 869/869 [00:29<00:00, 29.49it/s]
100%|██████████| 1665/1665 [00:55<00:00, 31.17it/s]