<a href="https://colab.research.google.com/github/Holy-Morphism/GenAI/blob/main/TextAugmentor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text Augmenter

## Install libraries

In [None]:
# Install NLTK (imported below for WordNet, tokenization and POS tagging) and
# the Parrot paraphraser, which is installed directly from its GitHub repository.
# NOTE(review): no cell visible here downloads the NLTK data packages that the
# WordNet/tokenizer calls rely on — confirm they are available in the runtime.
!pip install nltk
!pip install git+https://github.com/PrithivirajDamodaran/Parrot_Paraphraser.git

## Import Libraries

In [None]:
import json
import random
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize
from nltk.tag import pos_tag

# paraphrasing
from parrot import Parrot
import torch
import warnings
warnings.filterwarnings("ignore")

## Load the Parrot Model

In [None]:
# Create the Parrot paraphraser once from the T5-based model tag; this same
# instance is passed to TextAugmenter further down.
parrot = Parrot(model_tag="prithivida/parrot_paraphraser_on_T5")

## Text Augmenter Class

In [None]:
class TextAugmenter:
    """Apply text-augmentation transforms to a JSON dataset.

    The dataset is read from ``data_path`` when the object is created and is
    treated as an iterable of dicts that each carry a ``"text"`` key. Every
    transform takes a string and returns a string, so it can be passed to
    :meth:`apply_to_data`, which overwrites each item's ``"text"`` in place.
    """

    def __init__(self, data_path, output_path, parrot):
        """Load the dataset and remember where to write it and which paraphraser to use.

        Args:
            data_path: Path of the JSON file to read.
            output_path: Path of the JSON file written by :meth:`save_data`.
            parrot: Paraphraser object exposing an ``augment`` method
                (used only by :meth:`paraphrase`).
        """
        self.data_path = data_path
        self.output_path = output_path
        self.data = self.load_data()
        self.parrot = parrot

    def load_data(self):
        """Return the parsed contents of the JSON file at ``self.data_path``."""
        # Explicit UTF-8 so non-ASCII text loads the same on every platform.
        with open(self.data_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def save_data(self):
        """Write ``self.data`` to ``self.output_path`` as indented JSON."""
        with open(self.output_path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=4)

    def get_synonyms(self, word):
        """
        Returns a sorted list of WordNet synonyms for a given word.

        Lemma names are lower-cased and have underscores and hyphens replaced
        by spaces. The word itself is excluded, so callers never "replace" a
        word with itself or insert a duplicate of it.
        """
        synonyms = set()
        for synset in wordnet.synsets(word):
            for lemma in synset.lemmas():
                synonyms.add(lemma.name().replace("_", " ").replace("-", " ").lower())
        synonyms.discard(word.lower())
        # Sorted so that random.choice() is reproducible under random.seed();
        # plain set ordering of strings varies between interpreter runs.
        return sorted(synonyms)

    def apply_to_data(self, function, **kwargs):
        """
        Applies a given function to the 'text' field of each data item.

        The result replaces the item's 'text' in place; extra keyword
        arguments are forwarded to ``function`` unchanged.
        """
        for item in self.data:
            item["text"] = function(item["text"], **kwargs)

    def random_deletion(self, text, p=0.2):
        """
        Randomly deletes words from a sentence with probability p.

        A single-word text is returned unchanged. If every word of a longer
        text happens to be deleted, one randomly chosen original word is
        returned instead of an empty string.
        """
        words = text.split()
        if len(words) == 1:
            return text
        kept = [word for word in words if random.random() > p]
        if words and not kept:
            # Never hand back an empty sample just because the dice were unlucky.
            return random.choice(words)
        return " ".join(kept)

    def random_insertion(self, text, p=0.2):
        """
        Randomly inserts synonyms into a sentence with probability p.

        For each word, with probability p, one of its synonyms (if any) is
        inserted immediately before the word; the word itself is always kept.
        """
        new_words = []
        for word in text.split():
            if random.random() < p:
                synonyms = self.get_synonyms(word)
                if synonyms:
                    new_words.append(random.choice(synonyms))
            new_words.append(word)
        return " ".join(new_words)

    def paraphrase(self, text, use_gpu=None):
        """
        Returns a paraphrase of the text, or the text itself if none is produced.

        Args:
            text: Phrase to paraphrase.
            use_gpu: Whether the paraphraser should run on the GPU. ``None``
                (the default) uses the GPU only when CUDA is available, so the
                call also works on CPU-only runtimes.
        """
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()
        paraphrases = self.parrot.augment(
            input_phrase=text,
            use_gpu=use_gpu,
            do_diverse=True,            # Enable this to get more diverse paraphrases
            adequacy_threshold=0.50,    # Lower this number if no paraphrases are returned
            fluency_threshold=0.80,
        )
        if not paraphrases:
            return text  # Nothing was generated: keep the original text
        # Presumably each entry is a (phrase, score) pair; take the first phrase.
        return paraphrases[0][0]

    def synonym_replacement(self, text, p=0.2):
        """
        Replaces words with synonyms with probability p.

        The text is split with ``word_tokenize`` and re-joined with single
        spaces. Tokens without any synonym are kept as they are.
        """
        new_words = []
        for word in word_tokenize(text):
            if random.random() < p:
                synonyms = self.get_synonyms(word)
                if synonyms:
                    word = random.choice(synonyms)
            new_words.append(word)
        return " ".join(new_words)

In [None]:
# Load "announcements.json"; results are written to "augmented_data.json".
# Each apply_to_data() call overwrites every item's "text" in place, so
# enabling several transforms below chains them on top of each other.
augmenter = TextAugmenter("announcements.json", "augmented_data.json", parrot)

augmenter.apply_to_data(augmenter.paraphrase)  # Apply paraphrasing
augmenter.save_data()  # Save intermediate result (optional)

# augmenter.apply_to_data(augmenter.random_deletion, p=0.1)  # Apply random deletion
# augmenter.save_data()  # Save intermediate result (optional)



# augmenter.apply_to_data(augmenter.random_insertion, p=0.1)  # Apply random insertion
# augmenter.save_data()  # Save intermediate result (optional)

# # ... apply other functions in a similar way

# augmenter.apply_to_data(augmenter.synonym_replacement, p=0.2)  # Apply synonym replacement
# augmenter.save_data()  # Save final result