## BART for Paraphrasing with Simple Transformers
Paraphrasing is the act of expressing something using different words while retaining the original meaning. 

BART is a denoising autoencoder for pretraining sequence-to-sequence models. BART is trained by (1) corrupting text with an arbitrary noising function, and (2) learning a model to reconstruct the original text.
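
To make the two training steps concrete, here is a minimal sketch of one simple noising function, token masking, applied to a made-up sentence. The helper name `mask_tokens`, the masking probability, and the example sentence are assumptions for illustration only; BART's actual pretraining works on subword tokens and combines several noising schemes.

```python
import random

MASK = "<mask>"  # placeholder string standing in for the mask token


def mask_tokens(tokens, mask_prob=0.3, seed=0):
    """Replace each token with MASK with probability mask_prob."""
    rng = random.Random(seed)
    return [MASK if rng.random() < mask_prob else tok for tok in tokens]


original = "paraphrasing keeps the meaning but changes the words".split()
corrupted = mask_tokens(original)

# A denoising model is trained on (corrupted, original) pairs:
# the encoder reads `corrupted`, the decoder must emit `original`.
print("input :", " ".join(corrupted))
print("target:", " ".join(original))
```

The point of the sketch is only the shape of the training pair: noisy text in, clean text out.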

## BART Sequence-to-Sequence
BART has both a bidirectional encoder (like BERT) and an autoregressive decoder (like GPT), essentially getting the best of both worlds.

The encoder reads the corrupted input in both directions, much as BERT reads masked text, while the decoder attempts to reproduce the original sequence (autoencoder), token by token, using the previous (uncorrupted) tokens and the output from the encoder.

![BART Seq2Seq](https://miro.medium.com/max/512/1*jWecxbzBsJEbNgiy_LOIvw.png)
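
The "previous (uncorrupted) tokens" part of the decoder description is usually implemented by shifting the target sequence one position to the right during training. The sketch below shows that shift on a toy token list; the `START` and `END` markers and the helper `decoder_training_pairs` are hypothetical names chosen for this example, not the special tokens or functions of any particular library.

```python
START = "<s>"   # hypothetical start-of-sequence marker for this sketch
END = "</s>"    # hypothetical end-of-sequence marker for this sketch


def decoder_training_pairs(target_tokens):
    """Return (decoder_input, labels) where labels are the inputs shifted left by one."""
    decoder_input = [START] + target_tokens
    labels = target_tokens + [END]
    return decoder_input, labels


decoder_input, labels = decoder_training_pairs(["the", "cat", "sat"])
for step, expected in enumerate(labels):
    # At each step the decoder sees only the clean tokens that come before it.
    print(decoder_input[: step + 1], "->", expected)
```

At inference time there is no target to shift, so the decoder feeds its own previous outputs back in instead.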

## SETUP and INSTALLATIONS

In [None]:
!pip install simpletransformers

## DATA PREPARATION
We will be combining two datasets to serve as training data for our BART Paraphrasing Model.
1. Google PAWS-Wiki Labeled (Final)
2. Quora Question Pairs Dataset
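
Both datasets contain labeled sentence pairs, and only the pairs labeled as paraphrases are useful as training targets. As a rough, standard-library-only sketch of that filtering step, the snippet below parses a few made-up rows laid out with the `sentence1`, `sentence2`, and `label` columns used later in this notebook (the `id` column and the row contents are assumptions for illustration).

```python
import csv
import io

# Made-up rows; the real files are downloaded in the next cell.
tsv = (
    "id\tsentence1\tsentence2\tlabel\n"
    "1\tA flew to B.\tB flew to A.\t0\n"
    "2\tHe was born in 1900.\tIn 1900 he was born.\t1\n"
)

pairs = [
    {
        "prefix": "paraphrase",
        "input_text": row["sentence1"],
        "target_text": row["sentence2"],
    }
    for row in csv.DictReader(io.StringIO(tsv), delimiter="\t")
    if row["label"] == "1"  # keep only pairs labeled as paraphrases
]
print(pairs)
```

The notebook itself does the same thing with pandas, which scales better to the full datasets.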


In [2]:
!mkdir data
!wget https://storage.googleapis.com/paws/english/paws_wiki_labeled_final.tar.gz -P data
!tar -xvf data/paws_wiki_labeled_final.tar.gz -C data
!mv data/final/* data
!rm -r data/final

!wget http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv -P data

--2022-01-23 10:28:41--  https://storage.googleapis.com/paws/english/paws_wiki_labeled_final.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.195.128, 108.177.98.128, 74.125.142.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.195.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4687157 (4.5M) [application/gzip]
Saving to: ‘data/paws_wiki_labeled_final.tar.gz’


2022-01-23 10:28:42 (34.5 MB/s) - ‘data/paws_wiki_labeled_final.tar.gz’ saved [4687157/4687157]

final/test.tsv
final/
final/train.tsv
final/dev.tsv
--2022-01-23 10:28:42--  http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv
Resolving qim.fs.quoracdn.net (qim.fs.quoracdn.net)... 151.101.1.2, 151.101.65.2, 151.101.129.2, ...
Connecting to qim.fs.quoracdn.net (qim.fs.quoracdn.net)|151.101.1.2|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58176133 (55M) [text/tab-separated-values]
Saving to: ‘data/quora_duplicate_qu

In [3]:
import warnings

import pandas as pd


def load_data(
    file_path, input_text_column, target_text_column, label_column, keep_label=1
):
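    # error_bad_lines was deprecated in pandas 1.3 and removed in 2.0;
    # on_bad_lines="skip" is the replacement (requires pandas >= 1.3).
    df = pd.read_csv(file_path, sep="\t", on_bad_lines="skip")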
    df = df.loc[df[label_column] == keep_label]
    df = df.rename(
        columns={input_text_column: "input_text", target_text_column: "target_text"}
    )
    df = df[["input_text", "target_text"]]
    df["prefix"] = "paraphrase"

    return df


def clean_unnecessary_spaces(out_string):
    if not isinstance(out_string, str):
        warnings.warn(f">>> {out_string} <<< is not a string.")
        out_string = str(out_string)
    out_string = (
        out_string.replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
    )
    return out_string
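
To see what this cleanup does, the sketch below applies the same replacements, in the same order, to a made-up tokenized sentence. The names `REPLACEMENTS` and `detokenize` are introduced here only for illustration; the loop is meant as a compact restatement of the chained `.replace()` calls, not a different method.

```python
# Same (old, new) pairs, in the same order, as clean_unnecessary_spaces.
REPLACEMENTS = [
    (" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","), (" ' ", "'"),
    (" n't", "n't"), (" 'm", "'m"), (" 's", "'s"), (" 've", "'ve"), (" 're", "'re"),
]


def detokenize(text):
    """Remove the spaces a tokenizer left before punctuation and contractions."""
    for old, new in REPLACEMENTS:
        text = text.replace(old, new)
    return text


sample = "It 's Bob 's car , is n't it ?"  # made-up example sentence
print(detokenize(sample))  # It's Bob's car, isn't it?
```

Cleaning both the input and the target text this way keeps the model from learning to reproduce tokenizer spacing in its paraphrases.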

In [4]:
import os
from datetime import datetime
import logging

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

In [5]:
##### Load the datasets #####

# Google Data
train_df = pd.read_csv("data/train.tsv", sep="\t").astype(str)
eval_df = pd.read_csv("data/dev.tsv", sep="\t").astype(str)

train_df = train_df.loc[train_df["label"] == "1"]
eval_df = eval_df.loc[eval_df["label"] == "1"]

train_df = train_df.rename(
    columns={"sentence1": "input_text", "sentence2": "target_text"}
)
eval_df = eval_df.rename(
    columns={"sentence1": "input_text", "sentence2": "target_text"}
)

train_df = train_df[["input_text", "target_text"]]
eval_df = eval_df[["input_text", "target_text"]]

train_df["prefix"] = "paraphrase"
eval_df["prefix"] = "paraphrase"


# Quora Data

# The Quora Dataset is not separated into train/test, so we do it manually the first time.
df = load_data(
    "data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate"
)
q_train, q_test = train_test_split(df)

q_train.to_csv("data/quora_train.tsv", sep="\t")
q_test.to_csv("data/quora_test.tsv", sep="\t")

# The code block above only needs to be run once.
# After that, the two lines below are sufficient to load the Quora dataset.

# q_train = pd.read_csv("data/quora_train.tsv", sep="\t")
# q_test = pd.read_csv("data/quora_test.tsv", sep="\t")

train_df = pd.concat([train_df, q_train])
eval_df = pd.concat([eval_df, q_test])

train_df = train_df[["prefix", "input_text", "target_text"]]
eval_df = eval_df[["prefix", "input_text", "target_text"]]

train_df = train_df.dropna()
eval_df = eval_df.dropna()

train_df["input_text"] = train_df["input_text"].apply(clean_unnecessary_spaces)
train_df["target_text"] = train_df["target_text"].apply(clean_unnecessary_spaces)

eval_df["input_text"] = eval_df["input_text"].apply(clean_unnecessary_spaces)
eval_df["target_text"] = eval_df["target_text"].apply(clean_unnecessary_spaces)

print(train_df)

INFO:numexpr.utils:NumExpr defaulting to 2 threads.


            prefix  ...                                        target_text
1       paraphrase  ...  The 1975 -- 76 season of the National Basketba...
3       paraphrase  ...  The results are high when comparable flow rate...
4       paraphrase  ...  It is the seat of the district of Zerendi in A...
5       paraphrase  ...  William Henry Harman was born in Waynesboro, V...
7       paraphrase  ...  Given a discrete set of probabilities formula ...
...            ...  ...                                                ...
389651  paraphrase  ...                Can I own and fly a drone in india?
215502  paraphrase  ...  Why is value of anything power 0 is 1.. i.e. x...
214623  paraphrase  ...  What is the difference between a 32-bit and 64...
31849   paraphrase  ...  Which is the most romantic way to propose a girl?
362650  paraphrase  ...  Where is the best place to learn UI IX designing?

[133776 rows x 3 columns]


## Setup MODEL AND HYPERPARAMETER VALUES

In [None]:
model_args = Seq2SeqArgs()
model_args.do_sample = True
model_args.eval_batch_size = 64
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 2500
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False
model_args.learning_rate = 5e-5
model_args.max_length = 128
model_args.max_seq_length = 128
model_args.num_beams = None
model_args.num_return_sequences = 3
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True
model_args.reprocess_input_data = True
model_args.save_eval_checkpoints = False
model_args.save_steps = -1
model_args.top_k = 50
model_args.top_p = 0.95
model_args.train_batch_size = 8
model_args.use_multiprocessing = False
model_args.wandb_project = "Paraphrasing with BART"


model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
)

model.train_model(train_df, eval_data=eval_df)

to_predict = [
    prefix + ": " + str(input_text)
    for prefix, input_text in zip(eval_df["prefix"].tolist(), eval_df["input_text"].tolist())
]
truth = eval_df["target_text"].tolist()

preds = model.predict(to_predict)

# Saving the predictions if needed
os.makedirs("predictions", exist_ok=True)

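# Use a filesystem-safe timestamp (no spaces or colons) in the file name.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
with open(f"predictions/predictions_{timestamp}.txt", "w") as f: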
    for i, text in enumerate(eval_df["input_text"].tolist()):
        f.write(str(text) + "\n\n")

        f.write("Truth:\n")
        f.write(truth[i] + "\n\n")

        f.write("Prediction:\n")
        for pred in preds[i]:
            f.write(str(pred) + "\n")
        f.write(
            "________________________________________________________________________________\n"
        )

Downloading:   0%|          | 0.00/1.56k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/971M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

INFO:simpletransformers.seq2seq.seq2seq_utils: Creating features from dataset file at cache_dir/


  0%|          | 0/133776 [00:00<?, ?it/s]

INFO:simpletransformers.seq2seq.seq2seq_model: Training started


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

wandb: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit: ··········


wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Running Epoch 0 of 2:   0%|          | 0/16722 [00:00<?, ?it/s]
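
The decoding settings in the cell above (`do_sample = True`, `top_k = 50`, `top_p = 0.95`) ask for sampled rather than greedy output, which is what lets the model return several different paraphrases per input. The following is a simplified, pure-Python sketch of top-k followed by top-p (nucleus) filtering on a made-up next-token distribution. The function name `top_k_top_p_filter` and the toy probabilities are assumptions for illustration; the real implementation in the underlying library operates on logits and may differ in detail.

```python
import random


def top_k_top_p_filter(probs, top_k=50, top_p=0.95):
    """Keep the top_k most likely tokens, then the smallest prefix of them
    whose cumulative (renormalized) probability reaches top_p."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    k_total = sum(p for _, p in ranked)
    kept, cumulative = [], 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p / k_total
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    # Renormalize so the surviving probabilities sum to 1.
    return {token: p / total for token, p in kept}


# Made-up next-token distribution.
probs = {"house": 0.5, "home": 0.3, "flat": 0.15, "castle": 0.04, "igloo": 0.01}
filtered = top_k_top_p_filter(probs, top_k=4, top_p=0.9)

rng = random.Random(0)
choice = rng.choices(list(filtered), weights=list(filtered.values()))[0]
print(filtered, choice)
```

Lower `top_k` or `top_p` values restrict sampling to safer, more likely tokens; higher values allow more varied but riskier paraphrases.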

## 5 sample results

In [None]:
model_args.do_sample = True
model_args.num_beams = None
model_args.num_return_sequences = 3
model_args.max_length = 128
model_args.top_k = 50
model_args.top_p = 0.95

import logging

from simpletransformers.seq2seq import Seq2SeqModel


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

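# Pass model_args explicitly so the decoding settings above take effect;
# otherwise the arguments saved alongside the model in "outputs" are used.
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="outputs",
    args=model_args,
)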

input_sentences = [
    "Don't step on the broken glass.",
    "She looked into the mirror and saw another person.",
    "Garlic ice-cream was her favorite.",
    "Be careful with that butter knife.",
    "She was the type of girl who wanted to live in a pink house.",
]
for original in input_sentences:
    to_predict = [original]

    preds = model.predict(to_predict)

    print("---------------------------------------------------------")
    print(original)

    print()
    print("Predictions >>>")
    for pred in preds[0]:
        print(pred)

    print("---------------------------------------------------------")
    print()

## REFERENCES

1. https://towardsdatascience.com/bart-for-paraphrasing-with-simple-transformers-7c9ea3dfdd8c
