In [None]:
!pip install sentence_transformers==1.2.0
!pip --quiet install --upgrade transformers sentencepiece sentence-splitter strsim num2words sentence-transformers
!pip --quiet install git+https://github.com/LIAAD/yake python-rake sentence-splitter wordtodigits strsim
!pip install ax-platform==0.1.9
!pip install mlrose
!pip install git+https://github.com/boudinfl/pke.git
!pip --quiet install stemming quantulum3 scikit-learn==0.24.2 inflect
!pip install strsim

In [None]:
from abc import ABC, abstractmethod
import torch
from transformers import FSMTForConditionalGeneration, FSMTTokenizer
from sentence_splitter import SentenceSplitter
from nltk.tokenize import word_tokenize, sent_tokenize
from similarity.normalized_levenshtein import NormalizedLevenshtein
import os
from sentence_transformers import SentenceTransformer, util
import re
import pke
import numpy as np
from num2words import num2words
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
import random
import spacy
import wordtodigits
import pickle
import inflect



In [None]:
class Operation(ABC):
    """
    Common interface for every single-step text augmentation in this notebook.

    Concrete operations subclass this and implement ``generate``.
    """

    @abstractmethod
    def generate(self, text, **kwargs):
        """
        Produce an augmented variant of ``text``.

        Subclasses must override this method; the base implementation has no
        body and returns ``None``. Extra keyword arguments are accepted so
        callers can pass options to every operation uniformly.
        """

In [None]:
class BackTranslate(Operation):
    """
    Paraphrases English text by translating it to German and back again
    with the FSMT WMT19 models.
    """

    def __init__(self):
        """
        Load the en-de and de-en tokenizers/models onto GPU when available.

        Local copies under the Drive resource folder are used when the
        tokenizer directory exists; otherwise the Hugging Face hub
        identifiers are used for both tokenizer and model.
        """
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        print("self.DEVICE", self.DEVICE)
        forward_tokenizer_path = "/content/drive/MyDrive/resources/models/FSMT/en-de/tokenizer"
        forward_model_path = "/content/drive/MyDrive/resources/models/FSMT/en-de/model"
        if not os.path.exists(forward_tokenizer_path):
            forward_tokenizer_path = forward_model_path = "facebook/wmt19-en-de"
        backward_tokenizer_path = "/content/drive/MyDrive/resources/models/FSMT/de-en/tokenizer"
        backward_model_path = "/content/drive/MyDrive/resources/models/FSMT/de-en/model"
        if not os.path.exists(backward_tokenizer_path):
            backward_tokenizer_path = backward_model_path = "facebook/wmt19-de-en"
        print("[INFO] Loading en-de model...")
        self.forward_tokenizer = FSMTTokenizer.from_pretrained(forward_tokenizer_path)
        self.forward_model = FSMTForConditionalGeneration.from_pretrained(forward_model_path).to(self.DEVICE)
        print("[INFO] Loading de-en model...")
        self.backward_tokenizer = FSMTTokenizer.from_pretrained(backward_tokenizer_path)
        self.backward_model = FSMTForConditionalGeneration.from_pretrained(backward_model_path).to(self.DEVICE)
        self.splitter = SentenceSplitter(language='en')

    def beam_search(self, text, tokenizer, model, num_beams=10):
        """Translate ``text`` with plain beam search and return the top hypothesis."""
        input_ids = tokenizer.encode(text, return_tensors="pt").to(self.DEVICE)
        beam_output = model.generate(
            input_ids,
            num_beams=num_beams,
            early_stopping=True)
        return tokenizer.decode(beam_output[0], skip_special_tokens=True)

    def diverse_beam_search(self, text, tokenizer, model, beam_groups=4, num_return_sequences=4, diversity_penalty=1.5):
        """
        Translate ``text`` with diverse (grouped) beam search.

        Uses ``beam_groups * 3`` beams split into ``beam_groups`` groups and
        returns a list of ``num_return_sequences`` decoded hypotheses.
        """
        input_ids = tokenizer.encode(text, return_tensors="pt").to(self.DEVICE)
        beam_outputs = model.generate(
            input_ids,
            num_beams=beam_groups * 3,
            num_return_sequences=num_return_sequences,
            diversity_penalty=diversity_penalty,
            num_beam_groups=beam_groups)
        return [tokenizer.decode(beam_output, skip_special_tokens=True) for beam_output in beam_outputs]

    def backtranslate_single(self, text):
        """
        Back-translate ``text`` sentence by sentence and return the candidate
        that differs most from the input (largest normalized Levenshtein
        distance).

        Every sentence contributes two paraphrases and all combinations are
        kept, so the candidate count doubles per sentence.
        NOTE(review): this is exponential in the number of sentences; fine
        for short inputs, consider pruning for long ones.
        """
        nl = NormalizedLevenshtein()
        candidates = [""]
        for sentence in sent_tokenize(text):
            translated = self.beam_search(sentence, self.forward_tokenizer, self.forward_model, num_beams=8)
            paraphrases = self.diverse_beam_search(
                translated, self.backward_tokenizer, self.backward_model,
                num_return_sequences=2, beam_groups=4, diversity_penalty=50.0)
            # Join with a space: plain concatenation glued the end of one
            # sentence to the start of the next ("...apples.He then...").
            candidates = [
                (prev + " " + paraphrase).strip()
                for paraphrase in paraphrases
                for prev in candidates
            ]
        # max() returns the first maximal element, same as sorted(..., reverse=True)[0].
        return max(candidates, key=lambda candidate: nl.distance(text, candidate))

    def generate(self, text, **kwargs):
        """Return a back-translated paraphrase of ``text``; kwargs are ignored."""
        return self.backtranslate_single(text)

In [None]:
class DeleteLastSentence(Operation):
    """Drops the final sentence of a text, or its last four tokens when there is only one sentence."""

    def __init__(self) -> None:
        super().__init__()

    def generate(self, text, **kwargs):
        """
        Return ``text`` without its last sentence.

        When NLTK finds at most one sentence, the last four whitespace
        tokens are removed instead (which yields an empty string for inputs
        of four tokens or fewer).
        """
        sentences = sent_tokenize(text)
        if len(sentences) <= 1:
            # Deletes last 4 tokens if there is only 1 sentence
            tokens = text.split()
            return " ".join(tokens[:-4])
        kept_sentences = sentences[:-1]
        return " ".join(kept_sentences)

In [None]:
class MostImportantPhraseRemover(Operation):
    """
    Removes the phrase (or word) whose deletion moves the sentence embedding
    furthest away from the original text.
    """

    def __init__(self) -> None:
        super().__init__()
        self.sentence_transformer = SentenceTransformer("paraphrase-mpnet-base-v2")

    def lose_most_important_word(self, text, soften=False):
        """
        Return ``text`` with its most important keyphrase or word removed.

        Candidates are built by deleting each of the top-4 TopicRank
        keyphrases; if keyphrase extraction fails or yields nothing, each
        whitespace token is deleted in turn instead. The candidate with the
        lowest cosine similarity to the original embedding is returned.

        ``soften`` is accepted for signature compatibility and is not used.
        Returns ``text`` unchanged when no candidate can be built (e.g. empty
        input).
        """
        keyphrases = []
        try:
            self.extractor = pke.unsupervised.TopicRank()
            self.extractor.load_document(input=text, language='en')
            self.extractor.candidate_selection()
            self.extractor.candidate_weighting()
            keyphrases = [phrase for phrase, _score in self.extractor.get_n_best(n=4)]
        except Exception:
            # Keyphrase extraction is best-effort: on failure fall back to the
            # word-level candidates below. ``Exception`` (not a bare except)
            # so KeyboardInterrupt/SystemExit still propagate.
            keyphrases = []

        if keyphrases:
            # pke reports keyphrases lower-cased, so match case-insensitively
            # to also remove capitalised occurrences in the text.
            processed_sentences = [
                re.sub(r"\s+", " ", re.sub(re.escape(keyphrase), " ", text, flags=re.IGNORECASE))
                for keyphrase in keyphrases
            ]
        else:
            tokenized = text.split()
            processed_sentences = [
                " ".join(tokenized[:idx] + tokenized[idx + 1:])
                for idx in range(len(tokenized))
            ]

        if not processed_sentences:
            # Nothing to remove; encoding/argmin on an empty list would raise.
            return text

        original_embedding = self.sentence_transformer.encode([text], convert_to_tensor=True)
        processed_embeddings = self.sentence_transformer.encode(processed_sentences, convert_to_tensor=True)
        cosine_scores = util.pytorch_cos_sim(original_embedding, processed_embeddings)
        idx = int(np.argmin(cosine_scores[0].cpu().numpy()))
        return processed_sentences[idx]

    def generate(self, text, **kwargs):
        """Apply ``lose_most_important_word`` to ``text``; kwargs are ignored."""
        return self.lose_most_important_word(text)

In [None]:
class Num2Words(Operation):
    """Spells out every number in a text (e.g. "12" -> "twelve")."""

    def __init__(self) -> None:
        super().__init__()

    def generate(self, text, **kwargs):
        """
        Return ``text`` with each integer or decimal replaced by its word
        form, whitespace collapsed, and the first character of every
        sentence upper-cased.

        Numbers are substituted in a single regex pass so that a short
        number never rewrites part of a longer one ("1" inside "12"), and a
        period is only treated as part of the number when digits follow it,
        so sentence-final periods ("... has 5.") are preserved.
        """
        text = re.sub(r"\d+(?:\.\d+)?", lambda match: num2words(match.group(0)), text)
        text = re.sub(r"\s+", " ", text)
        sentences = sent_tokenize(text)
        text = " ".join(sentence[0].upper() + sentence[1:] for sentence in sentences if sentence)
        return text.strip()

In [None]:
class Pegasus(Operation):
    """Paraphrases text with the ``tuner007/pegasus_paraphrase`` model."""

    def __init__(self) -> None:
        super().__init__()
        model_name = 'tuner007/pegasus_paraphrase'
        self.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.tokenizer = PegasusTokenizer.from_pretrained(model_name)
        self.model = PegasusForConditionalGeneration.from_pretrained(model_name).to(self.DEVICE)
        self.nl = NormalizedLevenshtein()

    def get_response(self, input_text, num_return_sequences=5, beam_groups=5):
        """
        Generate paraphrases of ``input_text`` with diverse beam search.

        Input and output are limited to 128 tokens; ``beam_groups * 2`` beams
        are used. Returns the decoded hypotheses as a list of strings.
        """
        encoded = self.tokenizer(
            input_text, max_length=128, truncation=True, return_tensors="pt"
        ).to(self.DEVICE)
        generated = self.model.generate(
            **encoded,
            max_length=128,
            num_beam_groups=beam_groups,
            diversity_penalty=5.0,
            num_beams=beam_groups * 2,
            num_return_sequences=num_return_sequences,
            temperature=0.5,
        )
        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)

    def run_pegasus(self, text, nl=NormalizedLevenshtein(), n=2, soften=False):
        """
        Return the paraphrase that differs most from ``text``.

        The unique model outputs are ranked by normalized Levenshtein
        distance to ``text`` (largest first); among the top ``n`` the first
        one with more than half as many words as ``text`` is returned.
        Returns ``None`` when no candidate qualifies. ``soften`` is unused.
        """
        if n > 5:
            num = n + 1
        else:
            num = 5
        unique_outputs = list(set(self.get_response(text, num_return_sequences=num)))
        scored = [(output, nl.distance(output, text)) for output in unique_outputs]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        for candidate, _distance in scored[:n]:
            if len(candidate.split()) > len(text.split()) // 2:
                return candidate
        return None

    def generate(self, text, **kwargs):
        """Paraphrase ``text`` using the instance's distance metric; kwargs are ignored."""
        return self.run_pegasus(text, nl=self.nl)

In [None]:
class RandomDeletion(Operation):
    """Deletes a few randomly chosen words from a text."""

    def __init__(self) -> None:
        super().__init__()

    def generate(self, text, **kwargs):
        """
        Return ``text`` with randomly selected whitespace tokens removed.

        The number of deletions is drawn from ``[1, max(4, len(words)//10))``,
        or fixed to 1 when ``soften=True`` is passed, and is capped at the
        number of words. Positions are drawn without replacement so exactly
        that many distinct words are removed. Text without any token is
        returned unchanged.
        """
        soften = kwargs.get("soften", False)
        words = text.split()
        if not words:
            # np.random.randint(0, 0) would raise on empty input.
            return text
        num = 1 if soften else np.random.randint(1, max(4, len(words) // 10))
        num = min(num, len(words))
        # A set gives O(1) membership tests and plain-int comparisons.
        drop = set(np.random.choice(len(words), size=num, replace=False).tolist())
        return " ".join(word for i, word in enumerate(words) if i not in drop)

In [None]:
class ReplaceNamedEntities(Operation):
    """
    Swaps named entities in a text for other entries from the same resource
    list, or removes them.
    """

    def __init__(self, resource_dir) -> None:
        """
        Load the spaCy pipeline and every ``*.txt`` file in ``resource_dir``.

        Each file becomes one list of lines in ``self.resources``. Files are
        read in sorted name order so the lookup order in ``get_replacement``
        does not depend on the filesystem.
        """
        super().__init__()
        self.nlp = spacy.load("en_core_web_sm")
        fpaths = [
            os.path.join(resource_dir, fname)
            for fname in sorted(os.listdir(resource_dir))
            if fname.endswith(".txt")
        ]
        self.resources = []
        for fpath in fpaths:
            # Context manager so the handle is closed deterministically.
            with open(fpath, "r") as handle:
                self.resources.append(handle.read().splitlines())

    def get_replacement(self, entity):
        """
        Pick a replacement string for ``entity``.

        If ``entity`` (case-insensitively) appears in a resource list, up to
        five random entries of the first such list become candidates, each
        padded with spaces. A single space (i.e. deletion) is always one of
        the candidates; one candidate is returned at random.
        """
        entity_lower = entity.lower()
        possibilities = []
        for resource in self.resources:
            if any(item.lower() == entity_lower for item in resource):
                # random.sample raises if asked for more items than exist.
                possibilities.extend(random.sample(resource, min(5, len(resource))))
                break
        possibilities = [" {} ".format(x) for x in possibilities]
        possibilities.append(" ")
        return random.choice(possibilities)

    def replace_named_entities(self, text, soften=False):
        """
        Replace or drop persons, organizations, products, events and places.

        Between one and three detected entities are chosen at random (only
        the first of them, in text order, when ``soften`` is true) and
        substituted via ``get_replacement``. Whitespace is collapsed in the
        result. Returns ``text`` unchanged when no matching entity is found.
        """
        doc = self.nlp(text)
        named_entities = set(["PERSON", "ORG", "PRODUCT", "EVENT", "GPE", "GEO"])
        ne = [
            (ent.text, ent.start_char, ent.end_char)
            for ent in doc.ents
            if ent.label_ in named_entities
        ]
        if len(ne) == 0:
            return text
        ne_new = random.sample(ne, np.random.randint(1, min(len(ne), 3) + 1))
        ne_new = sorted(ne_new, key=lambda x: x[1])
        if soften:
            ne_new = ne_new[:1]
        # ``shift`` tracks how much earlier edits shortened the text so the
        # original character offsets can still be used.
        shift = 0
        for (entity, start, end) in ne_new:
            replacement = self.get_replacement(entity)
            text = text[:start - shift] + replacement + text[end - shift:]
            shift += end - start - len(replacement)
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    def generate(self, text, **kwargs):
        """Run ``replace_named_entities``; honours the ``soften`` keyword."""
        soften = kwargs.get("soften", False)
        return self.replace_named_entities(text, soften)

In [None]:
class ReplaceNumericalEntities(Operation):
    """Replaces one or two numbers in a text with a vague quantifier."""

    def __init__(self) -> None:
        super().__init__()

    def generate(self, text, **kwargs):
        """
        Return ``text`` with some numbers replaced by a quantifier.

        Number words are first converted to digits with ``wordtodigits``.
        One quantifier (or the empty string) is picked at random, then one
        or two distinct numbers are chosen and every occurrence of them is
        replaced. Whole numbers are matched, so choosing "1" does not touch
        "12", and a trailing sentence period is not part of a number.
        Returns ``text`` unchanged when it contains no number.
        """
        possibilities = ["some", "a few", "many", "a lot of", ""]
        replacement = str(np.random.choice(possibilities))
        converted = wordtodigits.convert(text)

        number_pattern = r"\d+(?:\.\d+)?"
        # Sorted for a stable order under a fixed random seed.
        numbers = sorted(set(re.findall(number_pattern, converted)))
        if not numbers:
            # np.random.randint(1, 1) would raise; nothing to replace anyway.
            return text

        count = np.random.randint(1, min(len(numbers) + 1, 3))
        chosen = set(np.random.choice(numbers, count, replace=False).tolist())

        def _swap(match):
            # Compare the literal match instead of using the number as a
            # regex, where "." would act as a wildcard.
            number = match.group(0)
            return " " + replacement + " " if number in chosen else number

        result = re.sub(number_pattern, _swap, converted)
        result = re.sub(r"\s+", " ", result)
        return result.strip()

In [None]:
class ReplaceUnits(Operation):
    """
    Swaps measurement/currency units in a sentence for a different unit of
    the same category (e.g. "km" -> "cm", "hours" -> "weeks").
    """

    def __init__(self):
        """
        Build the unit tables.

        ``self.full`` / ``self.plurals`` hold, per category, the singular and
        plural unit names; ``self.abbreviated`` / ``self.abbreviated_pl`` map
        abbreviations to those names. All four lists are index-aligned by
        category.
        """
        currencies = ["dollar", "cent", "nickel", "penny", "quarter", "dime", "rupee", "paisa", "pound", "euro"]
        currency_abbrv = {}
        currency_symbs = ["$", "£", "₹", "€"]
        currency_symbs_abbrv = {}
        currency_short = ["rs", "USD", "EUR", "GBP"]
        currency_short_abbrv = {}
        length = ["inch", "meter", "metre", "kilometre", "kilometer", "centimeter", "millimeter", "millimetre", "centimetre", "foot", "mile"]
        length_abbrv = {"m": "metre", "km": "kilometre", "cm":"centimetre", "mm":"millimetre", "ft": "foot", "mil": "mile"}
        weight = ["gram", "kilogram"]
        weight_abbrv = {"g": "gram", "gm": "gram", "kg": "kilogram"}
        time = ["minute", "hour", "day", "year", "month", "week"]
        time_abbrv = {"min": "minute", "hr": "hour", "yr": "year", "wk": "week"}
        speed = ["kilometre per hour", "miles per hour", "metre per second"]
        speed_abbrv = {"km/hr": "kilometre per hour", "km/h": "kilometre per hour", "kmph": "kilometre per hour", "mph": "miles per hour", "m/s": "metre per second", "mps": "metre per second"}
        p = inflect.engine()
        self.full = [currencies, currency_symbs, currency_short, length, weight, time, speed]
        # One plural list per category, index-aligned with ``self.full``.
        # (A stray bare string "seconds" used to be appended here as an extra
        # top-level element; it was never reachable by index and broke the
        # alignment, so it is gone.)
        self.plurals = [[p.plural(unit) for unit in unit_type] for unit_type in self.full]
        self.abbreviated = [currency_abbrv, currency_symbs_abbrv, currency_short_abbrv, length_abbrv, weight_abbrv, time_abbrv, speed_abbrv]
        self.abbreviated_pl = [{p.plural(key): p.plural(value) for (key, value) in ab_dict.items()} for ab_dict in self.abbreviated]
        self.nl = NormalizedLevenshtein()

    def find_closest(self, word, wordlist):
        """Return ``word`` if it is in ``wordlist``, else the entry with the smallest normalized Levenshtein distance."""
        if word in wordlist:
            return word
        closest = min([(w, self.nl.distance(word, w)) for w in wordlist], key=lambda x: x[1])
        return closest[0]

    def untokenize(self, words):
        """
        Join tokens back into a sentence, undoing common tokenizer spacing.

        ref: https://github.com/commonsense/metanl/blob/master/metanl/token_utils.py#L28
        """
        text = " ".join(words)
        step1 = text.replace("`` ", '"').replace(" ''", '"').replace(". . .", "...")
        step2 = step1.replace(" ( ", " (").replace(" ) ", ") ")
        step3 = re.sub(r' ([.,:;?!%]+)([ \'"`])', r"\1\2", step2)
        step4 = re.sub(r" ([.,:;?!%]+)$", r"\1", step3)
        step5 = step4.replace(" '", "'").replace(" n't", "n't").replace("can not", "cannot")
        step6 = step5.replace(" ` ", " '")
        return step6.strip()

    def abbreviate(self, word, abbrvs):
        """
        Return the abbreviation in ``abbrvs`` whose full form is closest to
        ``word`` (first matching key), or ``word`` itself if none matches.
        """
        closest_word = self.find_closest(word, abbrvs.values())
        for a, w in abbrvs.items():
            if closest_word == w:
                return a
        return word

    def _pick_other(self, unit, pool):
        """Pick a random entry of ``pool`` that differs from ``unit`` (distance >= 0.2); "" if there is none."""
        # Filtering first is equivalent to the former rejection-sampling
        # loop but cannot spin forever when no entry qualifies.
        candidates = [c for c in pool if c != unit and self.nl.distance(unit, c) >= 0.2]
        if not candidates:
            return ""
        return str(np.random.choice(candidates))

    def find_replacement_full_index(self, unit, index, negative=True):
        """
        Look ``unit`` up in category ``index``.

        Returns "" when the (lower-cased, stripped) unit is neither a
        singular nor a plural name of that category. Otherwise returns a
        different unit of the same number (singular stays singular, plural
        stays plural) when ``negative`` is true, or the unit itself when it
        is false ("rs" is expanded to "Rupees").
        """
        unit_list = self.full[index]
        unit_list_pl = self.plurals[index]
        unit = unit.lower().strip()

        if unit in unit_list:
            if negative:
                return self._pick_other(unit, unit_list)
            if unit == "rs":
                return "Rupees"
            return unit

        if unit in unit_list_pl:
            if negative:
                # Sample from the plural list; resampling used to fall back
                # to the singular list and could turn "hours" into "week".
                return self._pick_other(unit, unit_list_pl)
            return unit
        return ""

    def find_replacement_full(self, unit, negative=True):
        """Try every category in order and return the first non-empty replacement, else ""."""
        for index in range(len(self.full)):
            replacement = self.find_replacement_full_index(unit, index, negative)
            if replacement:
                return replacement
        return ""

    def find_replacement_abbreviated(self, unit, negative=True):
        """
        Resolve an abbreviated unit (singular or plural form).

        With ``negative`` false the full unit name is returned; with
        ``negative`` true a different unit is sampled and abbreviated again.
        Returns "" for unknown abbreviations.
        NOTE(review): the re-abbreviated unit can coincide with the input
        (e.g. "m" -> "meter" -> nearest abbreviable "metre" -> "m").
        """
        unit = unit.strip().lower()

        for i, abb_type in enumerate(self.abbreviated):
            abb_type_pl = self.abbreviated_pl[i]
            if unit in abb_type:
                unit_full = abb_type[unit]
                replacement = self.find_replacement_full_index(unit_full, i, negative=negative)
                # Skip abbreviation for "" so "no replacement" is not turned
                # into an arbitrary closest abbreviation.
                if negative and replacement:
                    replacement = self.abbreviate(replacement, abb_type)
                return replacement

            if unit in abb_type_pl:
                unit_full = abb_type_pl[unit]
                replacement = self.find_replacement_full_index(unit_full, i, negative=negative)
                if negative and replacement:
                    replacement = self.abbreviate(replacement, abb_type_pl)
                return replacement

        return ""

    def find_replacement(self, unit, negative=True, abbreviated=False):
        """
        Find a replacement for ``unit``.

        When ``abbreviated`` is true only the abbreviation tables are used;
        otherwise full names are tried first and abbreviations second.
        """
        if abbreviated:
            return self.find_replacement_abbreviated(unit, negative)
        replacement = self.find_replacement_full(unit, negative)
        if replacement:
            return replacement
        return self.find_replacement_abbreviated(unit, negative)

    def change(self, sent, negative=True):
        """
        Replace units in ``sent`` and return the new sentence.

        Tokens of the form ``<digits><unit>`` have their unit suffix looked
        up as an abbreviation; all other tokens are looked up as a whole.
        After the first replacement, each further candidate stops the scan
        with probability 0.65. Returns "" when the re-joined sentence equals
        the input.
        """
        original_sent = sent
        tokenized = word_tokenize(sent)
        replaced = False

        for i, token in enumerate(tokenized):
            quants = re.findall(r"\d+([a-z/]{1,5})", token)
            if quants:
                if replaced and np.random.choice([True, False], p=[0.65, 0.35]):
                    break
                unit = quants[0]
                replacement = self.find_replacement(unit, abbreviated=True, negative=negative)
                if replacement == "":
                    continue
                tokenized[i] = token.replace(unit, replacement)
                replaced = True
            else:
                replacement = self.find_replacement(token, negative=negative)
                if not replacement:
                    continue
                if replaced and np.random.choice([True, False], p=[0.65, 0.35]):
                    break
                tokenized[i] = replacement
                replaced = True

        sent = self.untokenize(tokenized)
        if sent == original_sent:
            sent = ""
        return sent

    def generate(self, text, **kwargs):
        """Return ``text`` with units swapped for different ones; kwargs are ignored."""
        return self.change(text, negative=True)

In [None]:
class SameSentence(Operation):
    """Identity operation: hands the input text back untouched."""

    def __init__(self) -> None:
        super().__init__()

    def generate(self, text, **kwargs):
        """Return ``text`` exactly as received; keyword arguments are ignored."""
        unchanged = text
        return unchanged

In [None]:
class TF_IDF_Replacement(Operation):
    """
    Replaces the highest TF-IDF-weighted words of a text with random
    vocabulary words.
    """

    def __init__(self, resource_dir) -> None:
        """
        Load the fitted vectorizer from ``<resource_dir>/tfidf_aqua.pkl``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load
        pickle files from a trusted source.
        """
        super().__init__()
        tfidf_path = os.path.join(resource_dir, "tfidf_aqua.pkl")
        with open(tfidf_path, "rb") as handle:
            self.tfidf = pickle.load(handle)
        # ``get_feature_names`` was removed in newer scikit-learn releases;
        # prefer ``get_feature_names_out`` when the installed version has it.
        if hasattr(self.tfidf, "get_feature_names_out"):
            self.words = list(self.tfidf.get_feature_names_out())
        else:
            self.words = self.tfidf.get_feature_names()

    def __sample(self, n=3):
        """Return ``n`` distinct random vocabulary words."""
        return random.sample(self.words, n)

    def generate(self, text, **kwargs):
        """
        Return ``text`` with one or two of its top-weighted words replaced.

        The (up to) four highest TF-IDF features of ``text`` are ranked;
        features with zero weight, i.e. words that do not occur in the text,
        are ignored. One or two of the best (one when ``soften=True``) are
        replaced by random vocabulary words. Only the first whole-word
        occurrence is replaced, matched case-insensitively. The text is
        returned unchanged when none of its words is in the vocabulary.
        """
        soften = kwargs.get("soften", False)

        weights = self.tfidf.transform([text]).toarray()[0]
        top_k = min(4, len(weights))
        if top_k == 0:
            return text
        top_idx = np.argpartition(weights, -top_k)[-top_k:]

        num_replace = 1 if soften else np.random.randint(1, 3)
        replacements = self.__sample(num_replace)

        ranked = sorted(
            ((weights[idx], self.words[idx]) for idx in top_idx if weights[idx] > 0),
            reverse=True,
        )
        question = text
        for (_weight, word), replacement in zip(ranked[:num_replace], replacements):
            pattern = r"\b" + re.escape(word) + r"\b"
            # A callable replacement keeps backslashes in the word literal.
            question = re.sub(
                pattern,
                lambda _match, new_word=replacement: new_word,
                question,
                count=1,
                flags=re.IGNORECASE,
            )
        return question

In [None]:
class UnitExpansion(Operation):
    """
    Expands abbreviated units in a sentence to their full names
    (e.g. "km" -> "kilometre", "rs" -> "Rupees").
    """

    def __init__(self):
        """
        Build the unit tables.

        ``self.full`` / ``self.plurals`` hold, per category, the singular and
        plural unit names; ``self.abbreviated`` / ``self.abbreviated_pl`` map
        abbreviations to those names. All four lists are index-aligned by
        category.
        """
        currencies = ["dollar", "cent", "nickel", "penny", "quarter", "dime", "rupee", "paisa", "pound", "euro"]
        currency_abbrv = {}
        currency_symbs = ["$", "£", "₹", "€"]
        currency_symbs_abbrv = {}
        currency_short = ["rs", "USD", "EUR", "GBP"]
        currency_short_abbrv = {}
        length = ["inch", "meter", "metre", "kilometre", "kilometer", "centimeter", "millimeter", "millimetre", "centimetre", "foot", "mile"]
        length_abbrv = {"m": "metre", "km": "kilometre", "cm":"centimetre", "mm":"millimetre", "ft": "foot", "mil": "mile"}
        weight = ["gram", "kilogram"]
        weight_abbrv = {"g": "gram", "gm": "gram", "kg": "kilogram"}
        time = ["minute", "hour", "day", "year", "month", "week"]
        time_abbrv = {"min": "minute", "hr": "hour", "yr": "year", "wk": "week"}
        speed = ["kilometre per hour", "miles per hour", "metre per second"]
        speed_abbrv = {"km/hr": "kilometre per hour", "km/h": "kilometre per hour", "kmph": "kilometre per hour", "mph": "miles per hour", "m/s": "metre per second", "mps": "metre per second"}
        p = inflect.engine()
        self.full = [currencies, currency_symbs, currency_short, length, weight, time, speed]
        self.plurals = [[p.plural(unit) for unit in unit_type] for unit_type in self.full]
        self.abbreviated = [currency_abbrv, currency_symbs_abbrv, currency_short_abbrv, length_abbrv, weight_abbrv, time_abbrv, speed_abbrv]
        self.abbreviated_pl = [{p.plural(key): p.plural(value) for (key, value) in ab_dict.items()} for ab_dict in self.abbreviated]
        self.nl = NormalizedLevenshtein()

    def find_closest(self, word, wordlist):
        """Return ``word`` if it is in ``wordlist``, else the entry with the smallest normalized Levenshtein distance."""
        if word in wordlist:
            return word
        closest = min([(w, self.nl.distance(word, w)) for w in wordlist], key=lambda x: x[1])
        return closest[0]

    def untokenize(self, words):
        """
        Join tokens back into a sentence, undoing common tokenizer spacing.

        ref: https://github.com/commonsense/metanl/blob/master/metanl/token_utils.py#L28
        """
        text = " ".join(words)
        step1 = text.replace("`` ", '"').replace(" ''", '"').replace(". . .", "...")
        step2 = step1.replace(" ( ", " (").replace(" ) ", ") ")
        step3 = re.sub(r' ([.,:;?!%]+)([ \'"`])', r"\1\2", step2)
        step4 = re.sub(r" ([.,:;?!%]+)$", r"\1", step3)
        step5 = step4.replace(" '", "'").replace(" n't", "n't").replace("can not", "cannot")
        step6 = step5.replace(" ` ", " '")
        return step6.strip()

    def abbreviate(self, word, abbrvs):
        """
        Return the abbreviation in ``abbrvs`` whose full form is closest to
        ``word`` (first matching key), or ``word`` itself if none matches.
        """
        closest_word = self.find_closest(word, abbrvs.values())
        for a, w in abbrvs.items():
            if closest_word == w:
                return a
        return word

    def _pick_other(self, unit, pool):
        """Pick a random entry of ``pool`` that differs from ``unit`` (distance >= 0.2); "" if there is none."""
        # Filtering first is equivalent to the former rejection-sampling
        # loop but cannot spin forever when no entry qualifies.
        candidates = [c for c in pool if c != unit and self.nl.distance(unit, c) >= 0.2]
        if not candidates:
            return ""
        return str(np.random.choice(candidates))

    def find_replacement_full_index(self, unit, index, negative=True):
        """
        Look ``unit`` up in category ``index``.

        Returns "" when the (lower-cased, stripped) unit is neither a
        singular nor a plural name of that category. Otherwise returns the
        unit itself when ``negative`` is false ("rs" is expanded to
        "Rupees"), or a different unit of the same number (singular stays
        singular, plural stays plural) when ``negative`` is true.
        """
        unit_list = self.full[index]
        unit_list_pl = self.plurals[index]
        unit = unit.lower().strip()

        if unit in unit_list:
            if negative:
                return self._pick_other(unit, unit_list)
            if unit == "rs":
                return "Rupees"
            return unit

        if unit in unit_list_pl:
            if negative:
                # Sample from the plural list; resampling used to fall back
                # to the singular list and could turn "hours" into "week".
                return self._pick_other(unit, unit_list_pl)
            return unit
        return ""

    def find_replacement_full(self, unit, negative=True):
        """Try every category in order and return the first non-empty replacement, else ""."""
        for index in range(len(self.full)):
            replacement = self.find_replacement_full_index(unit, index, negative)
            if replacement:
                return replacement
        return ""

    def find_replacement_abbreviated(self, unit, negative=True):
        """
        Resolve an abbreviated unit (singular or plural form).

        With ``negative`` false the full unit name is returned; with
        ``negative`` true a different unit is sampled and abbreviated again.
        Returns "" for unknown abbreviations.
        """
        unit = unit.strip().lower()

        for i, abb_type in enumerate(self.abbreviated):
            abb_type_pl = self.abbreviated_pl[i]
            if unit in abb_type:
                unit_full = abb_type[unit]
                replacement = self.find_replacement_full_index(unit_full, i, negative=negative)
                # Skip abbreviation for "" so "no replacement" is not turned
                # into an arbitrary closest abbreviation.
                if negative and replacement:
                    replacement = self.abbreviate(replacement, abb_type)
                return replacement

            if unit in abb_type_pl:
                unit_full = abb_type_pl[unit]
                replacement = self.find_replacement_full_index(unit_full, i, negative=negative)
                if negative and replacement:
                    replacement = self.abbreviate(replacement, abb_type_pl)
                return replacement

        return ""

    def find_replacement(self, unit, negative=True, abbreviated=False):
        """
        Find a replacement for ``unit``.

        When ``abbreviated`` is true only the abbreviation tables are used;
        otherwise full names are tried first and abbreviations second.
        """
        if abbreviated:
            return self.find_replacement_abbreviated(unit, negative)
        replacement = self.find_replacement_full(unit, negative)
        if replacement:
            return replacement
        return self.find_replacement_abbreviated(unit, negative)

    def change(self, sent, negative=True):
        """
        Rewrite units in ``sent`` and return the new sentence.

        Tokens of the form ``<digits><unit>`` have their unit suffix looked
        up as an abbreviation; all other tokens are looked up as a whole.
        After the first replacement, each further candidate stops the scan
        with probability 0.65. Returns "" when the re-joined sentence equals
        the input.
        """
        original_sent = sent
        tokenized = word_tokenize(sent)
        replaced = False

        for i, token in enumerate(tokenized):
            quants = re.findall(r"\d+([a-z/]{1,5})", token)
            if quants:
                if replaced and np.random.choice([True, False], p=[0.65, 0.35]):
                    break
                unit = quants[0]
                replacement = self.find_replacement(unit, abbreviated=True, negative=negative)
                if replacement == "":
                    continue
                tokenized[i] = token.replace(unit, replacement)
                replaced = True
            else:
                replacement = self.find_replacement(token, negative=negative)
                if not replacement:
                    continue
                if replaced and np.random.choice([True, False], p=[0.65, 0.35]):
                    break
                tokenized[i] = replacement
                replaced = True

        sent = self.untokenize(tokenized)
        if sent == original_sent:
            sent = ""
        return sent

    def generate(self, text, **kwargs):
        """Return ``text`` with abbreviated units expanded; kwargs are ignored."""
        return self.change(text, negative=False)

In [None]:
class CombinedOperation(ABC):
    """Interface for generators that run several named operations on one text."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def generate(self, text, ops=[], **kwargs):
        """
        Apply the operations named in ``ops`` to ``text``.

        Subclasses must override this method; the base implementation has
        no body and returns ``None``.
        """

In [None]:
class PositiveSamples(CombinedOperation):
    """Generates meaning-preserving variants (positive samples) of a text."""

    def __init__(self):
        """Instantiate every positive operation; this loads the back-translation models."""
        super().__init__()
        self.operations = {
            "BackTranslate": BackTranslate(),
            "SameSentence": SameSentence(),
            "Num2Words": Num2Words(),
            "UnitExpansion": UnitExpansion()
        }

    def generate(self, text, ops=["BackTranslate", "SameSentence", "Num2Words", "UnitExpansion"], **kwargs):
        """
        Run the operations named in ``ops`` on ``text``.

        Unknown operation names are skipped. Returns one output per applied
        operation, in the order given by ``ops``. Extra keyword arguments
        are accepted, matching ``CombinedOperation.generate`` and
        ``NegativeSamples.generate``, and forwarded to each operation.
        """
        positives = []
        for op in ops:
            operator = self.operations.get(op)
            if operator is None:
                continue
            positives.append(operator.generate(text, **kwargs))
        return positives

In [None]:
class NegativeSamples(CombinedOperation):
    """Generates meaning-altering variants (negative samples) of a text."""

    def __init__(self, resource_dir="/content/drive/MyDrive/resources"):
        """
        Instantiate every negative operation plus the TF-IDF fallback.

        ``resource_dir`` is handed to ``ReplaceNamedEntities`` (which reads
        its ``*.txt`` lists from it) and to ``TF_IDF_Replacement`` (which
        reads ``tfidf_aqua.pkl`` from it); both require it and used to be
        constructed without it, which raised ``TypeError``.
        NOTE(review): the default mirrors the Drive layout used by
        ``BackTranslate`` and assumes both resources live in the same
        folder — confirm.
        """
        super().__init__()
        self.operations = {
            "MostImportantPhraseRemover": MostImportantPhraseRemover(),
            "DeleteLastSentence": DeleteLastSentence(),
            "ReplaceNamedEntities": ReplaceNamedEntities(resource_dir),
            "ReplaceNumericalEntities": ReplaceNumericalEntities(),
            "Pegasus": Pegasus(),
            "ReplaceUnits": ReplaceUnits(),
            # "NegateQuestion": NegateQuestion(),
        }
        self.backup = TF_IDF_Replacement(resource_dir)

    def operate(self, text, operator, soften=False, c=0):
        """
        Apply ``operator`` to ``text`` with sanity checks on the result.

        If the operator returns ``None``, fewer than two characters, or the
        unchanged input, the TF-IDF fallback is applied instead. If the
        result has fewer than half the words of the input, the operator is
        retried in softened mode, at most five times (tracked by ``c``).
        """
        initial_text = text
        initial_wordlen = len(text.split())
        text = operator.generate(text, soften=soften)

        if text is None or len(text) < 2 or text == initial_text:
            return self.backup.generate(initial_text, soften=soften)

        if len(text.split()) < initial_wordlen // 2 and c < 5:
            # Retry via this method; the previous call targeted a method
            # that does not exist on this class (``lose_information_single``).
            text = self.operate(initial_text, operator, soften=True, c=c + 1)

        return text

    def generate(self, text, ops=["MostImportantPhraseRemover", "DeleteLastSentence", "ReplaceNamedEntities",
                                  "ReplaceNumericalEntities", "Pegasus", "ReplaceUnits", "NegateQuestion"],
                                  **kwargs):
        """
        Run the operations named in ``ops`` on ``text`` through ``operate``.

        Unknown operation names (such as the disabled "NegateQuestion") are
        skipped. Returns one output per applied operation, in ``ops`` order.
        Extra keyword arguments are accepted and ignored.
        """
        negatives = []
        for op in ops:
            operator = self.operations.get(op)
            if operator is None:
                continue
            negatives.append(self.operate(text, operator))
        return negatives

In [None]:
class DataGenerator(object):
    """Produces positive and negative augmentations for a text."""

    def __init__(self, positive_generator=None, negative_generator=None):
        """
        Set up the two sample generators.

        Pre-built generators may be passed in, which avoids reloading their
        models and lets the caller configure them. Any generator left as
        ``None`` is created with its default constructor, as before.
        """
        if positive_generator is None:
            positive_generator = PositiveSamples()
        if negative_generator is None:
            negative_generator = NegativeSamples()
        self.positive_generator = positive_generator
        self.negative_generator = negative_generator

    def generate(self, text, positive_ops=["BackTranslate", "SameSentence", "Num2Words", "UnitExpansion"],
                negative_ops=["MostImportantPhraseRemover", "DeleteLastSentence", "ReplaceNamedEntities",
                            "ReplaceNumericalEntities", "Pegasus", "ReplaceUnits", "NegateQuestion"]):
        """
        Return ``(positives, negatives)`` for ``text``.

        ``positive_ops`` and ``negative_ops`` name the operations each
        generator should run; each generator returns a list of outputs.
        """
        positives = self.positive_generator.generate(text, ops=positive_ops)
        negatives = self.negative_generator.generate(text, ops=negative_ops)
        return positives, negatives